aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/llvm
diff options
context:
space:
mode:
authorJohannes Stoelp <johannes.stoelp@gmail.com>2021-09-25 00:48:45 +0200
committerJohannes Stoelp <johannes.stoelp@gmail.com>2021-09-25 00:48:45 +0200
commit6eb6ad9f574c783d471f6a863299af25b6f5a8c7 (patch)
tree38c087654f2c703d7d4c6afbf342aa9dd65557c9 /src/llvm
parent425ba77074347f71283f75839224f78bd94f2e10 (diff)
downloadllvm-kaleidoscope-rs-6eb6ad9f574c783d471f6a863299af25b6f5a8c7.tar.gz
llvm-kaleidoscope-rs-6eb6ad9f574c783d471f6a863299af25b6f5a8c7.zip
ch4: added jit
Diffstat (limited to 'src/llvm')
-rw-r--r--src/llvm/lljit.rs151
-rw-r--r--src/llvm/mod.rs44
2 files changed, 194 insertions, 1 deletions
diff --git a/src/llvm/lljit.rs b/src/llvm/lljit.rs
new file mode 100644
index 0000000..88c7059
--- /dev/null
+++ b/src/llvm/lljit.rs
@@ -0,0 +1,151 @@
+use llvm_sys::orc2::{
+ lljit::{
+ LLVMOrcCreateLLJIT, LLVMOrcLLJITAddLLVMIRModuleWithRT, LLVMOrcLLJITGetGlobalPrefix,
+ LLVMOrcLLJITGetMainJITDylib, LLVMOrcLLJITLookup, LLVMOrcLLJITRef,
+ },
+ LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess, LLVMOrcDefinitionGeneratorRef,
+ LLVMOrcJITDylibAddGenerator, LLVMOrcJITDylibCreateResourceTracker, LLVMOrcJITDylibRef,
+ LLVMOrcReleaseResourceTracker, LLVMOrcResourceTrackerRef, LLVMOrcResourceTrackerRemove,
+};
+
+use std::convert::TryFrom;
+use std::marker::PhantomData;
+
+use super::{Error, Module};
+use crate::SmallCStr;
+
+/// Marker trait to constrain function signatures that can be looked up in the JIT.
+pub trait JitFn {}
+
+impl JitFn for unsafe extern "C" fn() -> f64 {}
+
+pub struct LLJit {
+ jit: LLVMOrcLLJITRef,
+ dylib: LLVMOrcJITDylibRef,
+}
+
+impl LLJit {
+ /// Create a new LLJit instance.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer or an error.
+ pub fn new() -> LLJit {
+ let (jit, dylib) = unsafe {
+ let mut jit = std::ptr::null_mut();
+ let err = LLVMOrcCreateLLJIT(
+ &mut jit as _,
+ std::ptr::null_mut(), /* builder: nullptr -> default */
+ );
+
+ if let Some(err) = Error::from(err) {
+ panic!("Error: {}", err.as_str());
+ }
+
+ let dylib = LLVMOrcLLJITGetMainJITDylib(jit);
+ assert!(!dylib.is_null());
+
+ (jit, dylib)
+ };
+
+ LLJit { jit, dylib }
+ }
+
+ /// Add an LLVM IR module to the JIT. Return a [`ResourceTracker`], which when dropped, will
+ /// remove the code of the LLVM IR module from the JIT.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer or an error.
+ pub fn add_module(&self, module: Module) -> ResourceTracker<'_> {
+ let tsmod = module.into_raw_thread_safe_module();
+
+ let rt = unsafe {
+ let rt = LLVMOrcJITDylibCreateResourceTracker(self.dylib);
+ let err = LLVMOrcLLJITAddLLVMIRModuleWithRT(self.jit, rt, tsmod);
+
+ if let Some(err) = Error::from(err) {
+ panic!("Error: {}", err.as_str());
+ }
+
+ rt
+ };
+
+ ResourceTracker::new(rt)
+ }
+
+ /// Find the symbol with the name `sym` in the JIT.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the symbol is not found in the JIT.
+ pub fn find_symbol<F: JitFn>(&self, sym: &str) -> F {
+ let sym =
+ SmallCStr::try_from(sym).expect("Failed to convert 'sym' argument to small C string!");
+
+ unsafe {
+ let mut addr = 0u64;
+ let err = LLVMOrcLLJITLookup(self.jit, &mut addr as _, sym.as_ptr());
+
+ if let Some(err) = Error::from(err) {
+ panic!("Error: {}", err.as_str());
+ }
+
+ debug_assert_eq!(core::mem::size_of_val(&addr), core::mem::size_of::<F>());
+ std::mem::transmute_copy(&addr)
+ }
+ }
+
+ /// Enable lookup of dynamic symbols available in the current process from the JIT.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns an error.
+ pub fn enable_process_symbols(&self) {
+ unsafe {
+ let mut proc_syms_gen: LLVMOrcDefinitionGeneratorRef = std::ptr::null_mut();
+ let err = LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess(
+ &mut proc_syms_gen as _,
+ self.global_prefix(),
+ None, /* filter */
+ std::ptr::null_mut(), /* filter ctx */
+ );
+
+ if let Some(err) = Error::from(err) {
+ panic!("Error: {}", err.as_str());
+ }
+
+ LLVMOrcJITDylibAddGenerator(self.dylib, proc_syms_gen);
+ }
+ }
+
+ /// Return the global prefix character according to the LLJITs data layout.
+ fn global_prefix(&self) -> libc::c_char {
+ unsafe { LLVMOrcLLJITGetGlobalPrefix(self.jit) }
+ }
+}
+
+/// A resource handle to code added to an [`LLJit`] instance. When a `ResourceTracker` handle is
+/// dropped, the code corresponding to the handle will be removed from the JIT.
+pub struct ResourceTracker<'jit>(LLVMOrcResourceTrackerRef, PhantomData<&'jit ()>);
+
+impl<'jit> ResourceTracker<'jit> {
+ fn new(rt: LLVMOrcResourceTrackerRef) -> ResourceTracker<'jit> {
+ assert!(!rt.is_null());
+ ResourceTracker(rt, PhantomData)
+ }
+}
+
+impl Drop for ResourceTracker<'_> {
+ fn drop(&mut self) {
+ unsafe {
+ let err = LLVMOrcResourceTrackerRemove(self.0);
+
+ if let Some(err) = Error::from(err) {
+ panic!("Error: {}", err.as_str());
+ }
+
+ LLVMOrcReleaseResourceTracker(self.0);
+ };
+ }
+}
diff --git a/src/llvm/mod.rs b/src/llvm/mod.rs
index 01ed3f2..16e6bfd 100644
--- a/src/llvm/mod.rs
+++ b/src/llvm/mod.rs
@@ -8,17 +8,28 @@
//! For the scope of this tutorial we mainly use assertions to validate the results from the LLVM
//! API calls.
-use llvm_sys::{core::LLVMShutdown, prelude::LLVMBasicBlockRef};
+use llvm_sys::{
+ core::LLVMShutdown,
+ error::{LLVMDisposeErrorMessage, LLVMErrorRef, LLVMGetErrorMessage},
+ prelude::LLVMBasicBlockRef,
+ target::{
+ LLVM_InitializeNativeAsmParser, LLVM_InitializeNativeAsmPrinter,
+ LLVM_InitializeNativeTarget,
+ },
+};
+use std::ffi::CStr;
use std::marker::PhantomData;
mod builder;
+mod lljit;
mod module;
mod pass_manager;
mod type_;
mod value;
pub use builder::IRBuilder;
+pub use lljit::{LLJit, ResourceTracker};
pub use module::Module;
pub use pass_manager::FunctionPassManager;
pub use type_::Type;
@@ -28,6 +39,37 @@ pub use value::{FnValue, Value};
#[derive(Copy, Clone)]
pub struct BasicBlock<'llvm>(LLVMBasicBlockRef, PhantomData<&'llvm ()>);
+struct Error<'llvm>(&'llvm mut libc::c_char);
+
+impl<'llvm> Error<'llvm> {
+ fn from(err: LLVMErrorRef) -> Option<Error<'llvm>> {
+ (!err.is_null()).then(|| Error(unsafe { &mut *LLVMGetErrorMessage(err) }))
+ }
+
+ fn as_str(&self) -> &str {
+ unsafe { CStr::from_ptr(self.0) }
+ .to_str()
+ .expect("Expected valid UTF8 string from LLVM API")
+ }
+}
+
+impl Drop for Error<'_> {
+ fn drop(&mut self) {
+ unsafe {
+ LLVMDisposeErrorMessage(self.0 as *mut libc::c_char);
+ }
+ }
+}
+
+/// Initialize native target for corresponding to host (useful for jitting).
+pub fn initialize_native_taget() {
+ unsafe {
+ assert_eq!(LLVM_InitializeNativeTarget(), 0);
+ assert_eq!(LLVM_InitializeNativeAsmParser(), 0);
+ assert_eq!(LLVM_InitializeNativeAsmPrinter(), 0);
+ }
+}
+
/// Deallocate and destroy all "ManagedStatic" variables.
pub fn shutdown() {
unsafe {