aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJohannes Stoelp <johannes.stoelp@gmail.com>2021-09-25 00:48:45 +0200
committerJohannes Stoelp <johannes.stoelp@gmail.com>2021-09-25 00:48:45 +0200
commit6eb6ad9f574c783d471f6a863299af25b6f5a8c7 (patch)
tree38c087654f2c703d7d4c6afbf342aa9dd65557c9
parent425ba77074347f71283f75839224f78bd94f2e10 (diff)
downloadllvm-kaleidoscope-rs-6eb6ad9f574c783d471f6a863299af25b6f5a8c7.tar.gz
llvm-kaleidoscope-rs-6eb6ad9f574c783d471f6a863299af25b6f5a8c7.zip
ch4: added jit
-rw-r--r--src/codegen.rs35
-rw-r--r--src/llvm/lljit.rs151
-rw-r--r--src/llvm/mod.rs44
-rw-r--r--src/main.rs78
-rw-r--r--src/parser.rs4
5 files changed, 292 insertions, 20 deletions
diff --git a/src/codegen.rs b/src/codegen.rs
index 08c3039..61634ad 100644
--- a/src/codegen.rs
+++ b/src/codegen.rs
@@ -11,18 +11,21 @@ pub struct Codegen<'llvm, 'a> {
module: &'llvm Module,
builder: &'a IRBuilder<'llvm>,
fpm: &'a FunctionPassManager<'llvm>,
+ fn_protos: &'a mut HashMap<String, PrototypeAST>,
}
impl<'llvm, 'a> Codegen<'llvm, 'a> {
/// Compile either a [`PrototypeAST`] or a [`FunctionAST`] into the LLVM `module`.
pub fn compile(
module: &'llvm Module,
+ fn_protos: &mut HashMap<String, PrototypeAST>,
compilee: Either<&PrototypeAST, &FunctionAST>,
) -> CodegenResult<FnValue<'llvm>> {
- let cg = Codegen {
+ let mut cg = Codegen {
module,
builder: &IRBuilder::with_ctx(module),
fpm: &FunctionPassManager::with_ctx(module),
+ fn_protos,
};
let mut variables = HashMap::new();
@@ -59,7 +62,7 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> {
_ => Err("invalid binary operator".into()),
}
}
- ExprAST::Call(callee, args) => match self.module.get_fn(callee) {
+ ExprAST::Call(callee, args) => match self.get_function(callee) {
Some(callee) => {
if callee.args() != args.len() {
return Err("Incorrect # arguments passed".into());
@@ -99,14 +102,16 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> {
}
fn codegen_function(
- &self,
+ &mut self,
FunctionAST(proto, body): &FunctionAST,
named_values: &mut HashMap<&'llvm str, Value<'llvm>>,
) -> CodegenResult<FnValue<'llvm>> {
- let the_function = match self.module.get_fn(&proto.0) {
- Some(f) => f,
- None => self.codegen_prototype(proto),
- };
+ // Insert the function prototype into the `fn_protos` map to keep track for re-generating
+ // declarations in other modules.
+ self.fn_protos.insert(proto.0.clone(), proto.clone());
+
+ let the_function = self.get_function(&proto.0)
+ .expect("If proto not already generated, get_function will do for us since we updated fn_protos before-hand!");
if the_function.basic_blocks() > 0 {
return Err("Function cannot be redefined.".into());
@@ -138,4 +143,20 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> {
todo!("Failed to codegen function body, erase from module!");
}
}
+
+ /// Lookup function with `name` in the LLVM module and return the corresponding value reference.
+ /// If the function is not available in the module, check if the prototype is known and codegen
+ /// it.
+ /// Return [`None`] if the prototype is not known.
+ fn get_function(&self, name: &str) -> Option<FnValue<'llvm>> {
+ let callee = match self.module.get_fn(name) {
+ Some(callee) => callee,
+ None => {
+ let proto = self.fn_protos.get(name)?;
+ self.codegen_prototype(proto)
+ }
+ };
+
+ Some(callee)
+ }
}
diff --git a/src/llvm/lljit.rs b/src/llvm/lljit.rs
new file mode 100644
index 0000000..88c7059
--- /dev/null
+++ b/src/llvm/lljit.rs
@@ -0,0 +1,151 @@
+use llvm_sys::orc2::{
+ lljit::{
+ LLVMOrcCreateLLJIT, LLVMOrcLLJITAddLLVMIRModuleWithRT, LLVMOrcLLJITGetGlobalPrefix,
+ LLVMOrcLLJITGetMainJITDylib, LLVMOrcLLJITLookup, LLVMOrcLLJITRef,
+ },
+ LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess, LLVMOrcDefinitionGeneratorRef,
+ LLVMOrcJITDylibAddGenerator, LLVMOrcJITDylibCreateResourceTracker, LLVMOrcJITDylibRef,
+ LLVMOrcReleaseResourceTracker, LLVMOrcResourceTrackerRef, LLVMOrcResourceTrackerRemove,
+};
+
+use std::convert::TryFrom;
+use std::marker::PhantomData;
+
+use super::{Error, Module};
+use crate::SmallCStr;
+
+/// Marker trait to constrain function signatures that can be looked up in the JIT.
+pub trait JitFn {}
+
+impl JitFn for unsafe extern "C" fn() -> f64 {}
+
+pub struct LLJit {
+ jit: LLVMOrcLLJITRef,
+ dylib: LLVMOrcJITDylibRef,
+}
+
+impl LLJit {
+ /// Create a new LLJit instance.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer or an error.
+ pub fn new() -> LLJit {
+ let (jit, dylib) = unsafe {
+ let mut jit = std::ptr::null_mut();
+ let err = LLVMOrcCreateLLJIT(
+ &mut jit as _,
+ std::ptr::null_mut(), /* builder: nullptr -> default */
+ );
+
+ if let Some(err) = Error::from(err) {
+ panic!("Error: {}", err.as_str());
+ }
+
+ let dylib = LLVMOrcLLJITGetMainJITDylib(jit);
+ assert!(!dylib.is_null());
+
+ (jit, dylib)
+ };
+
+ LLJit { jit, dylib }
+ }
+
+ /// Add an LLVM IR module to the JIT. Return a [`ResourceTracker`], which when dropped, will
+ /// remove the code of the LLVM IR module from the JIT.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer or an error.
+ pub fn add_module(&self, module: Module) -> ResourceTracker<'_> {
+ let tsmod = module.into_raw_thread_safe_module();
+
+ let rt = unsafe {
+ let rt = LLVMOrcJITDylibCreateResourceTracker(self.dylib);
+ let err = LLVMOrcLLJITAddLLVMIRModuleWithRT(self.jit, rt, tsmod);
+
+ if let Some(err) = Error::from(err) {
+ panic!("Error: {}", err.as_str());
+ }
+
+ rt
+ };
+
+ ResourceTracker::new(rt)
+ }
+
+ /// Find the symbol with the name `sym` in the JIT.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the symbol is not found in the JIT.
+ pub fn find_symbol<F: JitFn>(&self, sym: &str) -> F {
+ let sym =
+ SmallCStr::try_from(sym).expect("Failed to convert 'sym' argument to small C string!");
+
+ unsafe {
+ let mut addr = 0u64;
+ let err = LLVMOrcLLJITLookup(self.jit, &mut addr as _, sym.as_ptr());
+
+ if let Some(err) = Error::from(err) {
+ panic!("Error: {}", err.as_str());
+ }
+
+ debug_assert_eq!(core::mem::size_of_val(&addr), core::mem::size_of::<F>());
+ std::mem::transmute_copy(&addr)
+ }
+ }
+
+ /// Enable lookup of dynamic symbols available in the current process from the JIT.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns an error.
+ pub fn enable_process_symbols(&self) {
+ unsafe {
+ let mut proc_syms_gen: LLVMOrcDefinitionGeneratorRef = std::ptr::null_mut();
+ let err = LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess(
+ &mut proc_syms_gen as _,
+ self.global_prefix(),
+ None, /* filter */
+ std::ptr::null_mut(), /* filter ctx */
+ );
+
+ if let Some(err) = Error::from(err) {
+ panic!("Error: {}", err.as_str());
+ }
+
+ LLVMOrcJITDylibAddGenerator(self.dylib, proc_syms_gen);
+ }
+ }
+
+ /// Return the global prefix character according to the LLJITs data layout.
+ fn global_prefix(&self) -> libc::c_char {
+ unsafe { LLVMOrcLLJITGetGlobalPrefix(self.jit) }
+ }
+}
+
+/// A resource handle to code added to an [`LLJit`] instance. When a `ResourceTracker` handle is
+/// dropped, the code corresponding to the handle will be removed from the JIT.
+pub struct ResourceTracker<'jit>(LLVMOrcResourceTrackerRef, PhantomData<&'jit ()>);
+
+impl<'jit> ResourceTracker<'jit> {
+ fn new(rt: LLVMOrcResourceTrackerRef) -> ResourceTracker<'jit> {
+ assert!(!rt.is_null());
+ ResourceTracker(rt, PhantomData)
+ }
+}
+
+impl Drop for ResourceTracker<'_> {
+ fn drop(&mut self) {
+ unsafe {
+ let err = LLVMOrcResourceTrackerRemove(self.0);
+
+ if let Some(err) = Error::from(err) {
+ panic!("Error: {}", err.as_str());
+ }
+
+ LLVMOrcReleaseResourceTracker(self.0);
+ };
+ }
+}
diff --git a/src/llvm/mod.rs b/src/llvm/mod.rs
index 01ed3f2..16e6bfd 100644
--- a/src/llvm/mod.rs
+++ b/src/llvm/mod.rs
@@ -8,17 +8,28 @@
//! For the scope of this tutorial we mainly use assertions to validate the results from the LLVM
//! API calls.
-use llvm_sys::{core::LLVMShutdown, prelude::LLVMBasicBlockRef};
+use llvm_sys::{
+ core::LLVMShutdown,
+ error::{LLVMDisposeErrorMessage, LLVMErrorRef, LLVMGetErrorMessage},
+ prelude::LLVMBasicBlockRef,
+ target::{
+ LLVM_InitializeNativeAsmParser, LLVM_InitializeNativeAsmPrinter,
+ LLVM_InitializeNativeTarget,
+ },
+};
+use std::ffi::CStr;
use std::marker::PhantomData;
mod builder;
+mod lljit;
mod module;
mod pass_manager;
mod type_;
mod value;
pub use builder::IRBuilder;
+pub use lljit::{LLJit, ResourceTracker};
pub use module::Module;
pub use pass_manager::FunctionPassManager;
pub use type_::Type;
@@ -28,6 +39,37 @@ pub use value::{FnValue, Value};
#[derive(Copy, Clone)]
pub struct BasicBlock<'llvm>(LLVMBasicBlockRef, PhantomData<&'llvm ()>);
+struct Error<'llvm>(&'llvm mut libc::c_char);
+
+impl<'llvm> Error<'llvm> {
+ fn from(err: LLVMErrorRef) -> Option<Error<'llvm>> {
+ (!err.is_null()).then(|| Error(unsafe { &mut *LLVMGetErrorMessage(err) }))
+ }
+
+ fn as_str(&self) -> &str {
+ unsafe { CStr::from_ptr(self.0) }
+ .to_str()
+ .expect("Expected valid UTF8 string from LLVM API")
+ }
+}
+
+impl Drop for Error<'_> {
+ fn drop(&mut self) {
+ unsafe {
+ LLVMDisposeErrorMessage(self.0 as *mut libc::c_char);
+ }
+ }
+}
+
+/// Initialize native target for corresponding to host (useful for jitting).
+pub fn initialize_native_taget() {
+ unsafe {
+ assert_eq!(LLVM_InitializeNativeTarget(), 0);
+ assert_eq!(LLVM_InitializeNativeAsmParser(), 0);
+ assert_eq!(LLVM_InitializeNativeAsmPrinter(), 0);
+ }
+}
+
/// Deallocate and destroy all "ManagedStatic" variables.
pub fn shutdown() {
unsafe {
diff --git a/src/main.rs b/src/main.rs
index a5b57d0..945d588 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -2,10 +2,11 @@ use llvm_kaleidoscope_rs::{
codegen::Codegen,
lexer::{Lexer, Token},
llvm,
- parser::Parser,
+ parser::{Parser, PrototypeAST},
Either,
};
+use std::collections::HashMap;
use std::io::Read;
fn main_loop<I>(mut parser: Parser<I>)
@@ -14,7 +15,28 @@ where
{
// Initialize LLVM module with its own context.
// We will emit LLVM IR into this module.
- let module = llvm::Module::new();
+ let mut module = llvm::Module::new();
+
+ // Create a new JIT, based on the LLVM LLJIT.
+ let jit = llvm::LLJit::new();
+
+ // Enable lookup of dynamic symbols in the current process from the JIT.
+ jit.enable_process_symbols();
+
+ // Keep track of prototype names to their respective ASTs.
+ //
+ // This is useful since we jit every function definition into its own LLVM module.
+ // To allow calling functions defined in previous LLVM modules we keep track of their
+ // prototypes and generate IR for their declarations when they are called from another module.
+ let mut fn_protos: HashMap<String, PrototypeAST> = HashMap::new();
+
+ // When adding an IR module to the JIT, it will hand out a ResourceTracker. When the
+ // ResourceTracker is dropped, the code generated from the corresponding module will be removed
+ // from the JIT.
+ //
+ // For each function we want to keep the code generated for the last definition, hence we need
+ // to keep their ResourceTracker alive.
+ let mut fn_jit_rt: HashMap<String, llvm::ResourceTracker> = HashMap::new();
loop {
match parser.cur_tok() {
@@ -25,9 +47,25 @@ where
}
Token::Def => match parser.parse_definition() {
Ok(func) => {
- println!("Parse 'def'\n{:?}", func);
- if let Ok(func) = Codegen::compile(&module, Either::B(&func)) {
- func.dump();
+ println!("Parse 'def'");
+ let func_name = &func.0 .0;
+
+ // If we already jitted that function, remove the last definition from the JIT
+ // by dropping the corresponding ResourceTracker.
+ fn_jit_rt.remove(func_name);
+
+ if let Ok(func_ir) = Codegen::compile(&module, &mut fn_protos, Either::B(&func))
+ {
+ func_ir.dump();
+
+ // Add module to the JIT.
+ let rt = jit.add_module(module);
+
+ // Keep track of the ResourceTracker to keep the module code in the JIT.
+ fn_jit_rt.insert(func_name.to_string(), rt);
+
+ // Initialize a new module.
+ module = llvm::Module::new();
}
}
Err(err) => {
@@ -37,9 +75,14 @@ where
},
Token::Extern => match parser.parse_extern() {
Ok(proto) => {
- println!("Parse 'extern'\n{:?}", proto);
- if let Ok(proto) = Codegen::compile(&module, Either::A(&proto)) {
- proto.dump();
+ println!("Parse 'extern'");
+ if let Ok(proto_ir) =
+ Codegen::compile(&module, &mut fn_protos, Either::A(&proto))
+ {
+ proto_ir.dump();
+
+ // Keep track of external function declaration.
+ fn_protos.insert(proto.0.clone(), proto);
}
}
Err(err) => {
@@ -49,9 +92,21 @@ where
},
_ => match parser.parse_top_level_expr() {
Ok(func) => {
- println!("Parse top-level expression\n{:?}", func);
- if let Ok(func) = Codegen::compile(&module, Either::B(&func)) {
+ println!("Parse top-level expression");
+ if let Ok(func) = Codegen::compile(&module, &mut fn_protos, Either::B(&func)) {
func.dump();
+
+ // Add module to the JIT. Code will be removed when `_rt` is dropped.
+ let _rt = jit.add_module(module);
+
+ // Initialize a new module.
+ module = llvm::Module::new();
+
+ // Call the top level expression.
+ let fp = jit.find_symbol::<unsafe extern "C" fn() -> f64>("__anon_expr");
+ unsafe {
+ println!("Evaluated to {}", fp());
+ }
}
}
Err(err) => {
@@ -83,6 +138,9 @@ fn main() {
// Throw first coin and initialize cur_tok.
parser.get_next_token();
+ // Initialize native target for jitting.
+ llvm::initialize_native_taget();
+
main_loop(parser);
// De-allocate managed static LLVM data.
diff --git a/src/parser.rs b/src/parser.rs
index af69a87..b3a26cb 100644
--- a/src/parser.rs
+++ b/src/parser.rs
@@ -18,7 +18,7 @@ pub enum ExprAST {
/// PrototypeAST - This class represents the "prototype" for a function,
/// which captures its name, and its argument names (thus implicitly the number
/// of arguments the function takes).
-#[derive(Debug, PartialEq)]
+#[derive(Debug, PartialEq, Clone)]
pub struct PrototypeAST(pub String, pub Vec<String>);
/// FunctionAST - This class represents a function definition itself.
@@ -307,7 +307,7 @@ where
/// Implement `std::unique_ptr<FunctionAST> ParseTopLevelExpr();` from the tutorial.
pub fn parse_top_level_expr(&mut self) -> ParseResult<FunctionAST> {
let e = self.parse_expression()?;
- let proto = PrototypeAST("".into(), Vec::new());
+ let proto = PrototypeAST("__anon_expr".into(), Vec::new());
Ok(FunctionAST(proto, e))
}
}