diff options
author | Johannes Stoelp <johannes.stoelp@gmail.com> | 2021-09-25 00:48:45 +0200 |
---|---|---|
committer | Johannes Stoelp <johannes.stoelp@gmail.com> | 2021-09-25 00:48:45 +0200 |
commit | 6eb6ad9f574c783d471f6a863299af25b6f5a8c7 (patch) | |
tree | 38c087654f2c703d7d4c6afbf342aa9dd65557c9 /src/llvm | |
parent | 425ba77074347f71283f75839224f78bd94f2e10 (diff) | |
download | llvm-kaleidoscope-rs-6eb6ad9f574c783d471f6a863299af25b6f5a8c7.tar.gz llvm-kaleidoscope-rs-6eb6ad9f574c783d471f6a863299af25b6f5a8c7.zip |
ch4: added jit
Diffstat (limited to 'src/llvm')
-rw-r--r-- | src/llvm/lljit.rs | 151 | ||||
-rw-r--r-- | src/llvm/mod.rs | 44 |
2 files changed, 194 insertions, 1 deletions
diff --git a/src/llvm/lljit.rs b/src/llvm/lljit.rs new file mode 100644 index 0000000..88c7059 --- /dev/null +++ b/src/llvm/lljit.rs @@ -0,0 +1,151 @@ +use llvm_sys::orc2::{ + lljit::{ + LLVMOrcCreateLLJIT, LLVMOrcLLJITAddLLVMIRModuleWithRT, LLVMOrcLLJITGetGlobalPrefix, + LLVMOrcLLJITGetMainJITDylib, LLVMOrcLLJITLookup, LLVMOrcLLJITRef, + }, + LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess, LLVMOrcDefinitionGeneratorRef, + LLVMOrcJITDylibAddGenerator, LLVMOrcJITDylibCreateResourceTracker, LLVMOrcJITDylibRef, + LLVMOrcReleaseResourceTracker, LLVMOrcResourceTrackerRef, LLVMOrcResourceTrackerRemove, +}; + +use std::convert::TryFrom; +use std::marker::PhantomData; + +use super::{Error, Module}; +use crate::SmallCStr; + +/// Marker trait to constrain function signatures that can be looked up in the JIT. +pub trait JitFn {} + +impl JitFn for unsafe extern "C" fn() -> f64 {} + +pub struct LLJit { + jit: LLVMOrcLLJITRef, + dylib: LLVMOrcJITDylibRef, +} + +impl LLJit { + /// Create a new LLJit instance. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer or an error. + pub fn new() -> LLJit { + let (jit, dylib) = unsafe { + let mut jit = std::ptr::null_mut(); + let err = LLVMOrcCreateLLJIT( + &mut jit as _, + std::ptr::null_mut(), /* builder: nullptr -> default */ + ); + + if let Some(err) = Error::from(err) { + panic!("Error: {}", err.as_str()); + } + + let dylib = LLVMOrcLLJITGetMainJITDylib(jit); + assert!(!dylib.is_null()); + + (jit, dylib) + }; + + LLJit { jit, dylib } + } + + /// Add an LLVM IR module to the JIT. Return a [`ResourceTracker`], which when dropped, will + /// remove the code of the LLVM IR module from the JIT. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer or an error. + pub fn add_module(&self, module: Module) -> ResourceTracker<'_> { + let tsmod = module.into_raw_thread_safe_module(); + + let rt = unsafe { + let rt = LLVMOrcJITDylibCreateResourceTracker(self.dylib); + let err = LLVMOrcLLJITAddLLVMIRModuleWithRT(self.jit, rt, tsmod); + + if let Some(err) = Error::from(err) { + panic!("Error: {}", err.as_str()); + } + + rt + }; + + ResourceTracker::new(rt) + } + + /// Find the symbol with the name `sym` in the JIT. + /// + /// # Panics + /// + /// Panics if the symbol is not found in the JIT. + pub fn find_symbol<F: JitFn>(&self, sym: &str) -> F { + let sym = + SmallCStr::try_from(sym).expect("Failed to convert 'sym' argument to small C string!"); + + unsafe { + let mut addr = 0u64; + let err = LLVMOrcLLJITLookup(self.jit, &mut addr as _, sym.as_ptr()); + + if let Some(err) = Error::from(err) { + panic!("Error: {}", err.as_str()); + } + + debug_assert_eq!(core::mem::size_of_val(&addr), core::mem::size_of::<F>()); + std::mem::transmute_copy(&addr) + } + } + + /// Enable lookup of dynamic symbols available in the current process from the JIT. + /// + /// # Panics + /// + /// Panics if LLVM API returns an error. + pub fn enable_process_symbols(&self) { + unsafe { + let mut proc_syms_gen: LLVMOrcDefinitionGeneratorRef = std::ptr::null_mut(); + let err = LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess( + &mut proc_syms_gen as _, + self.global_prefix(), + None, /* filter */ + std::ptr::null_mut(), /* filter ctx */ + ); + + if let Some(err) = Error::from(err) { + panic!("Error: {}", err.as_str()); + } + + LLVMOrcJITDylibAddGenerator(self.dylib, proc_syms_gen); + } + } + + /// Return the global prefix character according to the LLJITs data layout. + fn global_prefix(&self) -> libc::c_char { + unsafe { LLVMOrcLLJITGetGlobalPrefix(self.jit) } + } +} + +/// A resource handle to code added to an [`LLJit`] instance. When a `ResourceTracker` handle is +/// dropped, the code corresponding to the handle will be removed from the JIT. +pub struct ResourceTracker<'jit>(LLVMOrcResourceTrackerRef, PhantomData<&'jit ()>); + +impl<'jit> ResourceTracker<'jit> { + fn new(rt: LLVMOrcResourceTrackerRef) -> ResourceTracker<'jit> { + assert!(!rt.is_null()); + ResourceTracker(rt, PhantomData) + } +} + +impl Drop for ResourceTracker<'_> { + fn drop(&mut self) { + unsafe { + let err = LLVMOrcResourceTrackerRemove(self.0); + + if let Some(err) = Error::from(err) { + panic!("Error: {}", err.as_str()); + } + + LLVMOrcReleaseResourceTracker(self.0); + }; + } +} diff --git a/src/llvm/mod.rs b/src/llvm/mod.rs index 01ed3f2..16e6bfd 100644 --- a/src/llvm/mod.rs +++ b/src/llvm/mod.rs @@ -8,17 +8,28 @@ //! For the scope of this tutorial we mainly use assertions to validate the results from the LLVM //! API calls. -use llvm_sys::{core::LLVMShutdown, prelude::LLVMBasicBlockRef}; +use llvm_sys::{ + core::LLVMShutdown, + error::{LLVMDisposeErrorMessage, LLVMErrorRef, LLVMGetErrorMessage}, + prelude::LLVMBasicBlockRef, + target::{ + LLVM_InitializeNativeAsmParser, LLVM_InitializeNativeAsmPrinter, + LLVM_InitializeNativeTarget, + }, +}; +use std::ffi::CStr; use std::marker::PhantomData; mod builder; +mod lljit; mod module; mod pass_manager; mod type_; mod value; pub use builder::IRBuilder; +pub use lljit::{LLJit, ResourceTracker}; pub use module::Module; pub use pass_manager::FunctionPassManager; pub use type_::Type; @@ -28,6 +39,37 @@ pub use value::{FnValue, Value}; #[derive(Copy, Clone)] pub struct BasicBlock<'llvm>(LLVMBasicBlockRef, PhantomData<&'llvm ()>); +struct Error<'llvm>(&'llvm mut libc::c_char); + +impl<'llvm> Error<'llvm> { + fn from(err: LLVMErrorRef) -> Option<Error<'llvm>> { + (!err.is_null()).then(|| Error(unsafe { &mut *LLVMGetErrorMessage(err) })) + } + + fn as_str(&self) -> &str { + unsafe { CStr::from_ptr(self.0) } + .to_str() + .expect("Expected valid UTF8 string from LLVM API") + } +} + +impl Drop for Error<'_> { + fn drop(&mut self) { + unsafe { + LLVMDisposeErrorMessage(self.0 as *mut libc::c_char); + } + } +} + +/// Initialize native target for corresponding to host (useful for jitting). +pub fn initialize_native_taget() { + unsafe { + assert_eq!(LLVM_InitializeNativeTarget(), 0); + assert_eq!(LLVM_InitializeNativeAsmParser(), 0); + assert_eq!(LLVM_InitializeNativeAsmPrinter(), 0); + } +} + /// Deallocate and destroy all "ManagedStatic" variables. pub fn shutdown() { unsafe { |