use std::collections::HashMap; use crate::llvm::{Builder, FnValue, Module, Value}; use crate::parser::{ExprAST, FunctionAST, PrototypeAST}; use crate::Either; type CodegenResult = Result; /// Code generator from kaleidoscope AST to LLVM IR. pub struct Codegen<'llvm, 'a> { module: &'llvm Module, builder: &'a Builder<'llvm>, } impl<'llvm, 'a> Codegen<'llvm, 'a> { /// Compile either a [`PrototypeAST`] or a [`FunctionAST`] into the LLVM `module`. pub fn compile( module: &'llvm Module, compilee: Either<&PrototypeAST, &FunctionAST>, ) -> CodegenResult> { let cg = Codegen { module, builder: &Builder::with_ctx(module), }; let mut variables = HashMap::new(); match compilee { Either::A(proto) => Ok(cg.codegen_prototype(proto)), Either::B(func) => cg.codegen_function(func, &mut variables), } } fn codegen_expr( &self, expr: &ExprAST, named_values: &mut HashMap<&'llvm str, Value<'llvm>>, ) -> CodegenResult> { match expr { ExprAST::Number(num) => Ok(self.module.type_f64().const_f64(*num)), ExprAST::Variable(name) => match named_values.get(name.as_str()) { Some(value) => Ok(*value), None => Err("Unknown variable name".into()), }, ExprAST::Binary(binop, lhs, rhs) => { let l = self.codegen_expr(lhs, named_values)?; let r = self.codegen_expr(rhs, named_values)?; match binop { '+' => Ok(self.builder.fadd(l, r)), '-' => Ok(self.builder.fsub(l, r)), '*' => Ok(self.builder.fmul(l, r)), '<' => { let res = self.builder.fcmpult(l, r); // Turn bool into f64. Ok(self.builder.uitofp(res, self.module.type_f64())) } _ => Err("invalid binary operator".into()), } } ExprAST::Call(callee, args) => match self.module.get_fn(callee) { Some(callee) => { if callee.args() != args.len() { return Err("Incorrect # arguments passed".into()); } // Generate code for function argument expressions. let mut args: Vec> = args .iter() .map(|arg| self.codegen_expr(arg, named_values)) .collect::>()?; Ok(self.builder.call(callee, &mut args)) } None => Err("Unknown function referenced".into()), }, } } fn codegen_prototype(&self, PrototypeAST(name, args): &PrototypeAST) -> FnValue<'llvm> { let type_f64 = self.module.type_f64(); let mut doubles = Vec::new(); doubles.resize(args.len(), type_f64); // Build the function type: fn(f64, f64, ..) -> f64 let ft = self.module.type_fn(&mut doubles, type_f64); // Create the function declaration. let f = self.module.add_fn(name, ft); // Set the names of the function arguments. for idx in 0..f.args() { f.arg(idx).set_name(&args[idx]); } f } fn codegen_function( &self, FunctionAST(proto, body): &FunctionAST, named_values: &mut HashMap<&'llvm str, Value<'llvm>>, ) -> CodegenResult> { let the_function = match self.module.get_fn(&proto.0) { Some(f) => f, None => self.codegen_prototype(proto), }; if the_function.basic_blocks() > 0 { return Err("Function cannot be redefined.".into()); } // Create entry basic block to insert code. let bb = self.module.append_basic_block(the_function); self.builder.pos_at_end(bb); // New scope, clear the map with the function args. named_values.clear(); // Update the map with the current functions args. for idx in 0..the_function.args() { let arg = the_function.arg(idx); named_values.insert(arg.get_name(), arg); } // Codegen function body. if let Ok(ret) = self.codegen_expr(body, named_values) { self.builder.ret(ret); assert!(the_function.verify()); Ok(the_function) } else { todo!("Failed to codegen function body, erase from module!"); } } }