aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/main.rs
diff options
context:
space:
mode:
authorJohannes Stoelp <johannes.stoelp@gmail.com>2021-09-25 00:48:45 +0200
committerJohannes Stoelp <johannes.stoelp@gmail.com>2021-09-25 00:48:45 +0200
commit6eb6ad9f574c783d471f6a863299af25b6f5a8c7 (patch)
tree38c087654f2c703d7d4c6afbf342aa9dd65557c9 /src/main.rs
parent425ba77074347f71283f75839224f78bd94f2e10 (diff)
downloadllvm-kaleidoscope-rs-6eb6ad9f574c783d471f6a863299af25b6f5a8c7.tar.gz
llvm-kaleidoscope-rs-6eb6ad9f574c783d471f6a863299af25b6f5a8c7.zip
ch4: added jit
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs78
1 files changed, 68 insertions, 10 deletions
diff --git a/src/main.rs b/src/main.rs
index a5b57d0..945d588 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -2,10 +2,11 @@ use llvm_kaleidoscope_rs::{
codegen::Codegen,
lexer::{Lexer, Token},
llvm,
- parser::Parser,
+ parser::{Parser, PrototypeAST},
Either,
};
+use std::collections::HashMap;
use std::io::Read;
fn main_loop<I>(mut parser: Parser<I>)
@@ -14,7 +15,28 @@ where
{
// Initialize LLVM module with its own context.
// We will emit LLVM IR into this module.
- let module = llvm::Module::new();
+ let mut module = llvm::Module::new();
+
+ // Create a new JIT, based on the LLVM LLJIT.
+ let jit = llvm::LLJit::new();
+
+ // Enable lookup of dynamic symbols in the current process from the JIT.
+ jit.enable_process_symbols();
+
+ // Keep track of prototype names to their respective ASTs.
+ //
+ // This is useful since we jit every function definition into its own LLVM module.
+ // To allow calling functions defined in previous LLVM modules we keep track of their
+ // prototypes and generate IR for their declarations when they are called from another module.
+ let mut fn_protos: HashMap<String, PrototypeAST> = HashMap::new();
+
+ // When adding an IR module to the JIT, it will hand out a ResourceTracker. When the
+ // ResourceTracker is dropped, the code generated from the corresponding module will be removed
+ // from the JIT.
+ //
+ // For each function we want to keep the code generated for the last definition, hence we need
+ // to keep their ResourceTracker alive.
+ let mut fn_jit_rt: HashMap<String, llvm::ResourceTracker> = HashMap::new();
loop {
match parser.cur_tok() {
@@ -25,9 +47,25 @@ where
}
Token::Def => match parser.parse_definition() {
Ok(func) => {
- println!("Parse 'def'\n{:?}", func);
- if let Ok(func) = Codegen::compile(&module, Either::B(&func)) {
- func.dump();
+ println!("Parse 'def'");
+ let func_name = &func.0 .0;
+
+ // If we already jitted that function, remove the last definition from the JIT
+ // by dropping the corresponding ResourceTracker.
+ fn_jit_rt.remove(func_name);
+
+ if let Ok(func_ir) = Codegen::compile(&module, &mut fn_protos, Either::B(&func))
+ {
+ func_ir.dump();
+
+ // Add module to the JIT.
+ let rt = jit.add_module(module);
+
+ // Keep track of the ResourceTracker to keep the module code in the JIT.
+ fn_jit_rt.insert(func_name.to_string(), rt);
+
+ // Initialize a new module.
+ module = llvm::Module::new();
}
}
Err(err) => {
@@ -37,9 +75,14 @@ where
},
Token::Extern => match parser.parse_extern() {
Ok(proto) => {
- println!("Parse 'extern'\n{:?}", proto);
- if let Ok(proto) = Codegen::compile(&module, Either::A(&proto)) {
- proto.dump();
+ println!("Parse 'extern'");
+ if let Ok(proto_ir) =
+ Codegen::compile(&module, &mut fn_protos, Either::A(&proto))
+ {
+ proto_ir.dump();
+
+ // Keep track of external function declaration.
+ fn_protos.insert(proto.0.clone(), proto);
}
}
Err(err) => {
@@ -49,9 +92,21 @@ where
},
_ => match parser.parse_top_level_expr() {
Ok(func) => {
- println!("Parse top-level expression\n{:?}", func);
- if let Ok(func) = Codegen::compile(&module, Either::B(&func)) {
+ println!("Parse top-level expression");
+ if let Ok(func) = Codegen::compile(&module, &mut fn_protos, Either::B(&func)) {
func.dump();
+
+ // Add module to the JIT. Code will be removed when `_rt` is dropped.
+ let _rt = jit.add_module(module);
+
+ // Initialize a new module.
+ module = llvm::Module::new();
+
+ // Call the top level expression.
+ let fp = jit.find_symbol::<unsafe extern "C" fn() -> f64>("__anon_expr");
+ unsafe {
+ println!("Evaluated to {}", fp());
+ }
}
}
Err(err) => {
@@ -83,6 +138,9 @@ fn main() {
// Throw first coin and initialize cur_tok.
parser.get_next_token();
+ // Initialize native target for jitting.
+ llvm::initialize_native_taget();
+
main_loop(parser);
// De-allocate managed static LLVM data.