use std::collections::HashMap; use crate::llvm::{FnValue, FunctionPassManager, IRBuilder, Module, Value}; use crate::parser::{ExprAST, FunctionAST, PrototypeAST}; use crate::Either; type CodegenResult = Result; /// Code generator from kaleidoscope AST to LLVM IR. pub struct Codegen<'llvm, 'a> { module: &'llvm Module, builder: &'a IRBuilder<'llvm>, fpm: &'a FunctionPassManager<'llvm>, fn_protos: &'a mut HashMap, } impl<'llvm, 'a> Codegen<'llvm, 'a> { /// Compile either a [`PrototypeAST`] or a [`FunctionAST`] into the LLVM `module`. pub fn compile( module: &'llvm Module, fn_protos: &mut HashMap, compilee: Either<&PrototypeAST, &FunctionAST>, ) -> CodegenResult> { let mut cg = Codegen { module, builder: &IRBuilder::with_ctx(module), fpm: &FunctionPassManager::with_ctx(module), fn_protos, }; let mut variables = HashMap::new(); match compilee { Either::A(proto) => Ok(cg.codegen_prototype(proto)), Either::B(func) => cg.codegen_function(func, &mut variables), } } fn codegen_expr( &self, expr: &ExprAST, named_values: &mut HashMap>, ) -> CodegenResult> { match expr { ExprAST::Number(num) => Ok(self.module.type_f64().const_f64(*num)), ExprAST::Variable(name) => match named_values.get(name.as_str()) { Some(value) => Ok(*value), None => Err("Unknown variable name".into()), }, ExprAST::Binary(binop, lhs, rhs) => { let l = self.codegen_expr(lhs, named_values)?; let r = self.codegen_expr(rhs, named_values)?; match binop { '+' => Ok(self.builder.fadd(l, r)), '-' => Ok(self.builder.fsub(l, r)), '*' => Ok(self.builder.fmul(l, r)), '<' => { let res = self.builder.fcmpult(l, r); // Turn bool into f64. Ok(self.builder.uitofp(res, self.module.type_f64())) } _ => Err("invalid binary operator".into()), } } ExprAST::Call(callee, args) => match self.get_function(callee) { Some(callee) => { if callee.args() != args.len() { return Err("Incorrect # arguments passed".into()); } // Generate code for function argument expressions. let mut args: Vec> = args .iter() .map(|arg| self.codegen_expr(arg, named_values)) .collect::>()?; Ok(self.builder.call(callee, &mut args)) } None => Err("Unknown function referenced".into()), }, ExprAST::If { cond, then, else_ } => { // For 'if' expressions we are building the following CFG. // // ; cond // br // | // +-----+------+ // v v // ; then ; else // | | // +-----+------+ // v // ; merge // phi then, else // ret phi let cond_v = { // Codgen 'cond' expression. let v = self.codegen_expr(cond, named_values)?; // Compare 'v' against '0' as 'one = ordered not equal'. self.builder .fcmpone(v, self.module.type_f64().const_f64(0f64)) }; // Get the function we are currently inserting into. let the_function = self.builder.get_insert_block().get_parent(); // Create basic blocks for the 'then' / 'else' expressions as well as the return // instruction ('merge'). // // Append the 'then' basic block to the function, don't insert the 'else' and // 'merge' basic blocks yet. let then_bb = self.module.append_basic_block(the_function); let else_bb = self.module.create_basic_block(); let merge_bb = self.module.create_basic_block(); // Create a conditional branch based on the result of the 'cond' expression. self.builder.cond_br(cond_v, then_bb, else_bb); // Move to 'then' basic block and codgen the 'then' expression. self.builder.pos_at_end(then_bb); let then_v = self.codegen_expr(then, named_values)?; // Create unconditional branch to 'merge' block. self.builder.br(merge_bb); // Update reference to current basic block (in case the 'then' expression added new // basic blocks). let then_bb = self.builder.get_insert_block(); // Now append the 'else' basic block to the function. the_function.append_basic_block(else_bb); // Move to 'else' basic block and codgen the 'else' expression. self.builder.pos_at_end(else_bb); let else_v = self.codegen_expr(else_, named_values)?; // Create unconditional branch to 'merge' block. self.builder.br(merge_bb); // Update reference to current basic block (in case the 'else' expression added new // basic blocks). let else_bb = self.builder.get_insert_block(); // Now append the 'merge' basic block to the function. the_function.append_basic_block(merge_bb); // Move to 'merge' basic block. self.builder.pos_at_end(merge_bb); // Codegen the phi node returning the appropriate value depending on the branch // condition. let phi = self.builder.phi( self.module.type_f64(), &[(then_v, then_bb), (else_v, else_bb)], ); Ok(*phi) } ExprAST::For { var, start, end, step, body, } => { // For 'for' expression we build the following structure. // // entry: // init = start expression // br loop // loop: // i = phi [%init, %entry], [%new_i, %loop] // ; loop body ... // new_i = increment %i by step expression // ; check end condition and branch // end: // Compute initial value for the loop variable. let start_val = self.codegen_expr(start, named_values)?; let the_function = self.builder.get_insert_block().get_parent(); // Get current basic block (used in the loop variable phi node). let entry_bb = self.builder.get_insert_block(); // Add new basic block to emit loop body. let loop_bb = self.module.append_basic_block(the_function); self.builder.br(loop_bb); self.builder.pos_at_end(loop_bb); // Build phi not to pick loop variable in case we come from the 'entry' block. // Which is the case when we enter the loop for the first time. // We will add another incoming value once we computed the updated loop variable // below. let variable = self .builder .phi(self.module.type_f64(), &[(start_val, entry_bb)]); // Insert the loop variable into the named values map that it can be referenced // from the body as well as the end condition. // In case the loop variable shadows an existing variable remember the shared one. let old_val = named_values.insert(var.into(), *variable); // Generate the loop body. self.codegen_expr(body, named_values)?; // Generate step value expression if available else use '1'. let step_val = if let Some(step) = step { self.codegen_expr(step, named_values)? } else { self.module.type_f64().const_f64(1f64) }; // Increment loop variable. let next_var = self.builder.fadd(*variable, step_val); // Generate the loop end condition. let end_cond = self.codegen_expr(end, named_values)?; let end_cond = self .builder .fcmpone(end_cond, self.module.type_f64().const_f64(0f64)); // Get current basic block. let loop_end_bb = self.builder.get_insert_block(); // Add new basic block following the loop. let after_bb = self.module.append_basic_block(the_function); // Register additional incoming value for the loop variable. This will choose the // updated loop variable if we are iterating in the loop. variable.add_incoming(next_var, loop_end_bb); // Branch depending on the loop end condition. self.builder.cond_br(end_cond, loop_bb, after_bb); self.builder.pos_at_end(after_bb); // Restore the shadowed variable if there was one. if let Some(old_val) = old_val { // We inserted 'var' above so it must exist. *named_values.get_mut(var).unwrap() = old_val; } else { named_values.remove(var); } // Loops just always return 0. Ok(self.module.type_f64().const_f64(0f64)) } } } fn codegen_prototype(&self, PrototypeAST(name, args): &PrototypeAST) -> FnValue<'llvm> { let type_f64 = self.module.type_f64(); let mut doubles = Vec::new(); doubles.resize(args.len(), type_f64); // Build the function type: fn(f64, f64, ..) -> f64 let ft = self.module.type_fn(&mut doubles, type_f64); // Create the function declaration. let f = self.module.add_fn(name, ft); // Set the names of the function arguments. for idx in 0..f.args() { f.arg(idx).set_name(&args[idx]); } f } fn codegen_function( &mut self, FunctionAST(proto, body): &FunctionAST, named_values: &mut HashMap>, ) -> CodegenResult> { // Insert the function prototype into the `fn_protos` map to keep track for re-generating // declarations in other modules. self.fn_protos.insert(proto.0.clone(), proto.clone()); let the_function = self.get_function(&proto.0) .expect("If proto not already generated, get_function will do for us since we updated fn_protos before-hand!"); if the_function.basic_blocks() > 0 { return Err("Function cannot be redefined.".into()); } // Create entry basic block to insert code. let bb = self.module.append_basic_block(the_function); self.builder.pos_at_end(bb); // New scope, clear the map with the function args. named_values.clear(); // Update the map with the current functions args. for idx in 0..the_function.args() { let arg = the_function.arg(idx); named_values.insert(arg.get_name().into(), arg); } // Codegen function body. if let Ok(ret) = self.codegen_expr(body, named_values) { self.builder.ret(ret); assert!(the_function.verify()); // Run the optimization passes on the function. self.fpm.run(the_function); Ok(the_function) } else { todo!("Failed to codegen function body, erase from module!"); } } /// Lookup function with `name` in the LLVM module and return the corresponding value reference. /// If the function is not available in the module, check if the prototype is known and codegen /// it. /// Return [`None`] if the prototype is not known. fn get_function(&self, name: &str) -> Option> { let callee = match self.module.get_fn(name) { Some(callee) => callee, None => { let proto = self.fn_protos.get(name)?; self.codegen_prototype(proto) } }; Some(callee) } }