From 6eb6ad9f574c783d471f6a863299af25b6f5a8c7 Mon Sep 17 00:00:00 2001 From: Johannes Stoelp Date: Sat, 25 Sep 2021 00:48:45 +0200 Subject: ch4: added jit --- src/codegen.rs | 35 ++++++++++--- src/llvm/lljit.rs | 151 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/llvm/mod.rs | 44 +++++++++++++++- src/main.rs | 78 ++++++++++++++++++++++++---- src/parser.rs | 4 +- 5 files changed, 292 insertions(+), 20 deletions(-) create mode 100644 src/llvm/lljit.rs (limited to 'src') diff --git a/src/codegen.rs b/src/codegen.rs index 08c3039..61634ad 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -11,18 +11,21 @@ pub struct Codegen<'llvm, 'a> { module: &'llvm Module, builder: &'a IRBuilder<'llvm>, fpm: &'a FunctionPassManager<'llvm>, + fn_protos: &'a mut HashMap, } impl<'llvm, 'a> Codegen<'llvm, 'a> { /// Compile either a [`PrototypeAST`] or a [`FunctionAST`] into the LLVM `module`. pub fn compile( module: &'llvm Module, + fn_protos: &mut HashMap, compilee: Either<&PrototypeAST, &FunctionAST>, ) -> CodegenResult> { - let cg = Codegen { + let mut cg = Codegen { module, builder: &IRBuilder::with_ctx(module), fpm: &FunctionPassManager::with_ctx(module), + fn_protos, }; let mut variables = HashMap::new(); @@ -59,7 +62,7 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> { _ => Err("invalid binary operator".into()), } } - ExprAST::Call(callee, args) => match self.module.get_fn(callee) { + ExprAST::Call(callee, args) => match self.get_function(callee) { Some(callee) => { if callee.args() != args.len() { return Err("Incorrect # arguments passed".into()); @@ -99,14 +102,16 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> { } fn codegen_function( - &self, + &mut self, FunctionAST(proto, body): &FunctionAST, named_values: &mut HashMap<&'llvm str, Value<'llvm>>, ) -> CodegenResult> { - let the_function = match self.module.get_fn(&proto.0) { - Some(f) => f, - None => self.codegen_prototype(proto), - }; + // Insert the function prototype into the `fn_protos` map to keep track for re-generating + // declarations in other modules. + self.fn_protos.insert(proto.0.clone(), proto.clone()); + + let the_function = self.get_function(&proto.0) + .expect("If proto not already generated, get_function will do for us since we updated fn_protos before-hand!"); if the_function.basic_blocks() > 0 { return Err("Function cannot be redefined.".into()); @@ -138,4 +143,20 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> { todo!("Failed to codegen function body, erase from module!"); } } + + /// Lookup function with `name` in the LLVM module and return the corresponding value reference. + /// If the function is not available in the module, check if the prototype is known and codegen + /// it. + /// Return [`None`] if the prototype is not known. + fn get_function(&self, name: &str) -> Option> { + let callee = match self.module.get_fn(name) { + Some(callee) => callee, + None => { + let proto = self.fn_protos.get(name)?; + self.codegen_prototype(proto) + } + }; + + Some(callee) + } } diff --git a/src/llvm/lljit.rs b/src/llvm/lljit.rs new file mode 100644 index 0000000..88c7059 --- /dev/null +++ b/src/llvm/lljit.rs @@ -0,0 +1,151 @@ +use llvm_sys::orc2::{ + lljit::{ + LLVMOrcCreateLLJIT, LLVMOrcLLJITAddLLVMIRModuleWithRT, LLVMOrcLLJITGetGlobalPrefix, + LLVMOrcLLJITGetMainJITDylib, LLVMOrcLLJITLookup, LLVMOrcLLJITRef, + }, + LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess, LLVMOrcDefinitionGeneratorRef, + LLVMOrcJITDylibAddGenerator, LLVMOrcJITDylibCreateResourceTracker, LLVMOrcJITDylibRef, + LLVMOrcReleaseResourceTracker, LLVMOrcResourceTrackerRef, LLVMOrcResourceTrackerRemove, +}; + +use std::convert::TryFrom; +use std::marker::PhantomData; + +use super::{Error, Module}; +use crate::SmallCStr; + +/// Marker trait to constrain function signatures that can be looked up in the JIT. +pub trait JitFn {} + +impl JitFn for unsafe extern "C" fn() -> f64 {} + +pub struct LLJit { + jit: LLVMOrcLLJITRef, + dylib: LLVMOrcJITDylibRef, +} + +impl LLJit { + /// Create a new LLJit instance. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer or an error. + pub fn new() -> LLJit { + let (jit, dylib) = unsafe { + let mut jit = std::ptr::null_mut(); + let err = LLVMOrcCreateLLJIT( + &mut jit as _, + std::ptr::null_mut(), /* builder: nullptr -> default */ + ); + + if let Some(err) = Error::from(err) { + panic!("Error: {}", err.as_str()); + } + + let dylib = LLVMOrcLLJITGetMainJITDylib(jit); + assert!(!dylib.is_null()); + + (jit, dylib) + }; + + LLJit { jit, dylib } + } + + /// Add an LLVM IR module to the JIT. Return a [`ResourceTracker`], which when dropped, will + /// remove the code of the LLVM IR module from the JIT. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer or an error. + pub fn add_module(&self, module: Module) -> ResourceTracker<'_> { + let tsmod = module.into_raw_thread_safe_module(); + + let rt = unsafe { + let rt = LLVMOrcJITDylibCreateResourceTracker(self.dylib); + let err = LLVMOrcLLJITAddLLVMIRModuleWithRT(self.jit, rt, tsmod); + + if let Some(err) = Error::from(err) { + panic!("Error: {}", err.as_str()); + } + + rt + }; + + ResourceTracker::new(rt) + } + + /// Find the symbol with the name `sym` in the JIT. + /// + /// # Panics + /// + /// Panics if the symbol is not found in the JIT. + pub fn find_symbol(&self, sym: &str) -> F { + let sym = + SmallCStr::try_from(sym).expect("Failed to convert 'sym' argument to small C string!"); + + unsafe { + let mut addr = 0u64; + let err = LLVMOrcLLJITLookup(self.jit, &mut addr as _, sym.as_ptr()); + + if let Some(err) = Error::from(err) { + panic!("Error: {}", err.as_str()); + } + + debug_assert_eq!(core::mem::size_of_val(&addr), core::mem::size_of::()); + std::mem::transmute_copy(&addr) + } + } + + /// Enable lookup of dynamic symbols available in the current process from the JIT. + /// + /// # Panics + /// + /// Panics if LLVM API returns an error. + pub fn enable_process_symbols(&self) { + unsafe { + let mut proc_syms_gen: LLVMOrcDefinitionGeneratorRef = std::ptr::null_mut(); + let err = LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess( + &mut proc_syms_gen as _, + self.global_prefix(), + None, /* filter */ + std::ptr::null_mut(), /* filter ctx */ + ); + + if let Some(err) = Error::from(err) { + panic!("Error: {}", err.as_str()); + } + + LLVMOrcJITDylibAddGenerator(self.dylib, proc_syms_gen); + } + } + + /// Return the global prefix character according to the LLJITs data layout. + fn global_prefix(&self) -> libc::c_char { + unsafe { LLVMOrcLLJITGetGlobalPrefix(self.jit) } + } +} + +/// A resource handle to code added to an [`LLJit`] instance. When a `ResourceTracker` handle is +/// dropped, the code corresponding to the handle will be removed from the JIT. +pub struct ResourceTracker<'jit>(LLVMOrcResourceTrackerRef, PhantomData<&'jit ()>); + +impl<'jit> ResourceTracker<'jit> { + fn new(rt: LLVMOrcResourceTrackerRef) -> ResourceTracker<'jit> { + assert!(!rt.is_null()); + ResourceTracker(rt, PhantomData) + } +} + +impl Drop for ResourceTracker<'_> { + fn drop(&mut self) { + unsafe { + let err = LLVMOrcResourceTrackerRemove(self.0); + + if let Some(err) = Error::from(err) { + panic!("Error: {}", err.as_str()); + } + + LLVMOrcReleaseResourceTracker(self.0); + }; + } +} diff --git a/src/llvm/mod.rs b/src/llvm/mod.rs index 01ed3f2..16e6bfd 100644 --- a/src/llvm/mod.rs +++ b/src/llvm/mod.rs @@ -8,17 +8,28 @@ //! For the scope of this tutorial we mainly use assertions to validate the results from the LLVM //! API calls. -use llvm_sys::{core::LLVMShutdown, prelude::LLVMBasicBlockRef}; +use llvm_sys::{ + core::LLVMShutdown, + error::{LLVMDisposeErrorMessage, LLVMErrorRef, LLVMGetErrorMessage}, + prelude::LLVMBasicBlockRef, + target::{ + LLVM_InitializeNativeAsmParser, LLVM_InitializeNativeAsmPrinter, + LLVM_InitializeNativeTarget, + }, +}; +use std::ffi::CStr; use std::marker::PhantomData; mod builder; +mod lljit; mod module; mod pass_manager; mod type_; mod value; pub use builder::IRBuilder; +pub use lljit::{LLJit, ResourceTracker}; pub use module::Module; pub use pass_manager::FunctionPassManager; pub use type_::Type; @@ -28,6 +39,37 @@ pub use value::{FnValue, Value}; #[derive(Copy, Clone)] pub struct BasicBlock<'llvm>(LLVMBasicBlockRef, PhantomData<&'llvm ()>); +struct Error<'llvm>(&'llvm mut libc::c_char); + +impl<'llvm> Error<'llvm> { + fn from(err: LLVMErrorRef) -> Option> { + (!err.is_null()).then(|| Error(unsafe { &mut *LLVMGetErrorMessage(err) })) + } + + fn as_str(&self) -> &str { + unsafe { CStr::from_ptr(self.0) } + .to_str() + .expect("Expected valid UTF8 string from LLVM API") + } +} + +impl Drop for Error<'_> { + fn drop(&mut self) { + unsafe { + LLVMDisposeErrorMessage(self.0 as *mut libc::c_char); + } + } +} + +/// Initialize native target for corresponding to host (useful for jitting). +pub fn initialize_native_taget() { + unsafe { + assert_eq!(LLVM_InitializeNativeTarget(), 0); + assert_eq!(LLVM_InitializeNativeAsmParser(), 0); + assert_eq!(LLVM_InitializeNativeAsmPrinter(), 0); + } +} + /// Deallocate and destroy all "ManagedStatic" variables. pub fn shutdown() { unsafe { diff --git a/src/main.rs b/src/main.rs index a5b57d0..945d588 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,10 +2,11 @@ use llvm_kaleidoscope_rs::{ codegen::Codegen, lexer::{Lexer, Token}, llvm, - parser::Parser, + parser::{Parser, PrototypeAST}, Either, }; +use std::collections::HashMap; use std::io::Read; fn main_loop(mut parser: Parser) @@ -14,7 +15,28 @@ where { // Initialize LLVM module with its own context. // We will emit LLVM IR into this module. - let module = llvm::Module::new(); + let mut module = llvm::Module::new(); + + // Create a new JIT, based on the LLVM LLJIT. + let jit = llvm::LLJit::new(); + + // Enable lookup of dynamic symbols in the current process from the JIT. + jit.enable_process_symbols(); + + // Keep track of prototype names to their respective ASTs. + // + // This is useful since we jit every function definition into its own LLVM module. + // To allow calling functions defined in previous LLVM modules we keep track of their + // prototypes and generate IR for their declarations when they are called from another module. + let mut fn_protos: HashMap = HashMap::new(); + + // When adding an IR module to the JIT, it will hand out a ResourceTracker. When the + // ResourceTracker is dropped, the code generated from the corresponding module will be removed + // from the JIT. + // + // For each function we want to keep the code generated for the last definition, hence we need + // to keep their ResourceTracker alive. + let mut fn_jit_rt: HashMap = HashMap::new(); loop { match parser.cur_tok() { @@ -25,9 +47,25 @@ where } Token::Def => match parser.parse_definition() { Ok(func) => { - println!("Parse 'def'\n{:?}", func); - if let Ok(func) = Codegen::compile(&module, Either::B(&func)) { - func.dump(); + println!("Parse 'def'"); + let func_name = &func.0 .0; + + // If we already jitted that function, remove the last definition from the JIT + // by dropping the corresponding ResourceTracker. + fn_jit_rt.remove(func_name); + + if let Ok(func_ir) = Codegen::compile(&module, &mut fn_protos, Either::B(&func)) + { + func_ir.dump(); + + // Add module to the JIT. + let rt = jit.add_module(module); + + // Keep track of the ResourceTracker to keep the module code in the JIT. + fn_jit_rt.insert(func_name.to_string(), rt); + + // Initialize a new module. + module = llvm::Module::new(); } } Err(err) => { @@ -37,9 +75,14 @@ where }, Token::Extern => match parser.parse_extern() { Ok(proto) => { - println!("Parse 'extern'\n{:?}", proto); - if let Ok(proto) = Codegen::compile(&module, Either::A(&proto)) { - proto.dump(); + println!("Parse 'extern'"); + if let Ok(proto_ir) = + Codegen::compile(&module, &mut fn_protos, Either::A(&proto)) + { + proto_ir.dump(); + + // Keep track of external function declaration. + fn_protos.insert(proto.0.clone(), proto); } } Err(err) => { @@ -49,9 +92,21 @@ where }, _ => match parser.parse_top_level_expr() { Ok(func) => { - println!("Parse top-level expression\n{:?}", func); - if let Ok(func) = Codegen::compile(&module, Either::B(&func)) { + println!("Parse top-level expression"); + if let Ok(func) = Codegen::compile(&module, &mut fn_protos, Either::B(&func)) { func.dump(); + + // Add module to the JIT. Code will be removed when `_rt` is dropped. + let _rt = jit.add_module(module); + + // Initialize a new module. + module = llvm::Module::new(); + + // Call the top level expression. + let fp = jit.find_symbol:: f64>("__anon_expr"); + unsafe { + println!("Evaluated to {}", fp()); + } } } Err(err) => { @@ -83,6 +138,9 @@ fn main() { // Throw first coin and initialize cur_tok. parser.get_next_token(); + // Initialize native target for jitting. + llvm::initialize_native_taget(); + main_loop(parser); // De-allocate managed static LLVM data. diff --git a/src/parser.rs b/src/parser.rs index af69a87..b3a26cb 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -18,7 +18,7 @@ pub enum ExprAST { /// PrototypeAST - This class represents the "prototype" for a function, /// which captures its name, and its argument names (thus implicitly the number /// of arguments the function takes). -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone)] pub struct PrototypeAST(pub String, pub Vec); /// FunctionAST - This class represents a function definition itself. @@ -307,7 +307,7 @@ where /// Implement `std::unique_ptr ParseTopLevelExpr();` from the tutorial. pub fn parse_top_level_expr(&mut self) -> ParseResult { let e = self.parse_expression()?; - let proto = PrototypeAST("".into(), Vec::new()); + let proto = PrototypeAST("__anon_expr".into(), Vec::new()); Ok(FunctionAST(proto, e)) } } -- cgit v1.2.3