diff options
author | Johannes Stoelp <johannes.stoelp@gmail.com> | 2022-03-27 22:20:06 +0200 |
---|---|---|
committer | Johannes Stoelp <johannes.stoelp@gmail.com> | 2022-03-27 22:20:06 +0200 |
commit | 21ea78e57fa480d472d3660881e91813f7b18820 (patch) | |
tree | 476b2721ba44d41d9b2faf003701fdc647c9259e /src | |
parent | 4f6dd49df3f19204694fcea55f38efd9c5118bf2 (diff) | |
download | llvm-kaleidoscope-rs-21ea78e57fa480d472d3660881e91813f7b18820.tar.gz llvm-kaleidoscope-rs-21ea78e57fa480d472d3660881e91813f7b18820.zip |
ch5: added for loop
Diffstat (limited to 'src')
-rw-r--r-- | src/codegen.rs | 98 | ||||
-rw-r--r-- | src/lexer.rs | 11 | ||||
-rw-r--r-- | src/llvm/builder.rs | 16 | ||||
-rw-r--r-- | src/llvm/mod.rs | 2 | ||||
-rw-r--r-- | src/llvm/value.rs | 72 | ||||
-rw-r--r-- | src/parser.rs | 125 |
6 files changed, 303 insertions, 21 deletions
diff --git a/src/codegen.rs b/src/codegen.rs index 25e8c42..b2d15e7 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -38,7 +38,7 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> { fn codegen_expr( &self, expr: &ExprAST, - named_values: &mut HashMap<&'llvm str, Value<'llvm>>, + named_values: &mut HashMap<String, Value<'llvm>>, ) -> CodegenResult<Value<'llvm>> { match expr { ExprAST::Number(num) => Ok(self.module.type_f64().const_f64(*num)), @@ -97,7 +97,7 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> { let cond_v = { // Codgen 'cond' expression. let v = self.codegen_expr(cond, named_values)?; - // Convert condition to bool. + // Compare 'v' against '0' as 'one = ordered not equal'. self.builder .fcmpone(v, self.module.type_f64().const_f64(0f64)) }; @@ -148,7 +148,95 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> { &[(then_v, then_bb), (else_v, else_bb)], ); - Ok(phi) + Ok(*phi) + } + ExprAST::For { + var, + start, + end, + step, + body, + } => { + // For 'for' expression we build the following structure. + // + // entry: + // init = start expression + // br loop + // loop: + // i = phi [%init, %entry], [%new_i, %loop] + // ; loop body ... + // new_i = increment %i by step expression + // ; check end condition and branch + // end: + + // Compute initial value for the loop variable. + let start_val = self.codegen_expr(start, named_values)?; + + let the_function = self.builder.get_insert_block().get_parent(); + // Get current basic block (used in the loop variable phi node). + let entry_bb = self.builder.get_insert_block(); + // Add new basic block to emit loop body. + let loop_bb = self.module.append_basic_block(the_function); + + self.builder.br(loop_bb); + self.builder.pos_at_end(loop_bb); + + // Build phi not to pick loop variable in case we come from the 'entry' block. + // Which is the case when we enter the loop for the first time. + // We will add another incoming value once we computed the updated loop variable + // below. + let variable = self + .builder + .phi(self.module.type_f64(), &[(start_val, entry_bb)]); + + // Insert the loop variable into the named values map that it can be referenced + // from the body as well as the end condition. + // In case the loop variable shadows an existing variable remember the shared one. + let old_val = named_values.insert(var.into(), *variable); + + // Generate the loop body. + self.codegen_expr(body, named_values)?; + + // Generate step value expression if available else use '1'. + let step_val = if let Some(step) = step { + self.codegen_expr(step, named_values)? + } else { + self.module.type_f64().const_f64(1f64) + }; + + // Increment loop variable. + let next_var = self.builder.fadd(*variable, step_val); + + // Generate the loop end condition. + let end_cond = self.codegen_expr(end, named_values)?; + let end_cond = self + .builder + .fcmpone(end_cond, self.module.type_f64().const_f64(0f64)); + + // Get current basic block. + let loop_end_bb = self.builder.get_insert_block(); + // Add new basic block following the loop. + let after_bb = self.module.append_basic_block(the_function); + + // Register additional incoming value for the loop variable. This will choose the + // updated loop variable if we are iterating in the loop. + variable.add_incoming(next_var, loop_end_bb); + + // Branch depending on the loop end condition. + self.builder.cond_br(end_cond, loop_bb, after_bb); + + self.builder.pos_at_end(after_bb); + + // Restore the shadowed variable if there was one. + if let Some(old_val) = old_val { + // We inserted 'var' above so it must exist. + *named_values.get_mut(var).unwrap() = old_val; + } else { + named_values.remove(var); + } + + // Loops just always return 0. + Ok(self.module.type_f64().const_f64(0f64)) } } } @@ -176,7 +264,7 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> { fn codegen_function( &mut self, FunctionAST(proto, body): &FunctionAST, - named_values: &mut HashMap<&'llvm str, Value<'llvm>>, + named_values: &mut HashMap<String, Value<'llvm>>, ) -> CodegenResult<FnValue<'llvm>> { // Insert the function prototype into the `fn_protos` map to keep track for re-generating // declarations in other modules. @@ -199,7 +287,7 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> { // Update the map with the current functions args. for idx in 0..the_function.args() { let arg = the_function.arg(idx); - named_values.insert(arg.get_name(), arg); + named_values.insert(arg.get_name().into(), arg); } // Codegen function body. diff --git a/src/lexer.rs b/src/lexer.rs index fdab5b4..365b8bf 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -9,6 +9,8 @@ pub enum Token { If, Then, Else, + For, + In, } pub struct Lexer<I> @@ -68,6 +70,8 @@ where "if" => return Token::If, "then" => return Token::Then, "else" => return Token::Else, + "for" => return Token::For, + "in" => return Token::In, _ => {} } @@ -192,4 +196,11 @@ mod test { assert_eq!(Token::Then, lex.gettok()); assert_eq!(Token::Else, lex.gettok()); } + + #[test] + fn test_for() { + let mut lex = Lexer::new("for in".chars()); + assert_eq!(Token::For, lex.gettok()); + assert_eq!(Token::In, lex.gettok()); + } } diff --git a/src/llvm/builder.rs b/src/llvm/builder.rs index 8f581f9..da10231 100644 --- a/src/llvm/builder.rs +++ b/src/llvm/builder.rs @@ -10,7 +10,7 @@ use llvm_sys::{ use std::marker::PhantomData; -use super::{BasicBlock, FnValue, Module, Type, Value}; +use super::{BasicBlock, FnValue, Module, PhiValue, Type, Value}; // Definition of LLVM C API functions using our `repr(transparent)` types. extern "C" { @@ -131,8 +131,8 @@ impl<'llvm> IRBuilder<'llvm> { /// /// Panics if LLVM API returns a `null` pointer. pub fn fcmpult(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { - debug_assert!(lhs.is_f64(), "fcmplt: Expected f64 as lhs operand!"); - debug_assert!(rhs.is_f64(), "fcmplt: Expected f64 as rhs operand!"); + debug_assert!(lhs.is_f64(), "fcmpult: Expected f64 as lhs operand!"); + debug_assert!(rhs.is_f64(), "fcmpult: Expected f64 as rhs operand!"); let value_ref = unsafe { LLVMBuildFCmp( @@ -140,7 +140,7 @@ impl<'llvm> IRBuilder<'llvm> { LLVMRealPredicate::LLVMRealULT, lhs.value_ref(), rhs.value_ref(), - b"fcmplt\0".as_ptr().cast(), + b"fcmpult\0".as_ptr().cast(), ) }; Value::new(value_ref) @@ -152,8 +152,8 @@ impl<'llvm> IRBuilder<'llvm> { /// /// Panics if LLVM API returns a `null` pointer. pub fn fcmpone(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { - debug_assert!(lhs.is_f64(), "fcmplt: Expected f64 as lhs operand!"); - debug_assert!(rhs.is_f64(), "fcmplt: Expected f64 as rhs operand!"); + debug_assert!(lhs.is_f64(), "fcmone: Expected f64 as lhs operand!"); + debug_assert!(rhs.is_f64(), "fcmone: Expected f64 as rhs operand!"); let value_ref = unsafe { LLVMBuildFCmp( @@ -251,7 +251,7 @@ impl<'llvm> IRBuilder<'llvm> { &self, phi_type: Type<'llvm>, incoming: &[(Value<'llvm>, BasicBlock<'llvm>)], - ) -> Value<'llvm> { + ) -> PhiValue<'llvm> { let phi_ref = unsafe { LLVMBuildPhi(self.builder, phi_type.type_ref(), b"phi\0".as_ptr().cast()) }; assert!(!phi_ref.is_null()); @@ -268,7 +268,7 @@ impl<'llvm> IRBuilder<'llvm> { } } - Value::new(phi_ref) + PhiValue::new(phi_ref) } } diff --git a/src/llvm/mod.rs b/src/llvm/mod.rs index c9f17b6..1eb9c57 100644 --- a/src/llvm/mod.rs +++ b/src/llvm/mod.rs @@ -33,7 +33,7 @@ pub use lljit::{LLJit, ResourceTracker}; pub use module::Module; pub use pass_manager::FunctionPassManager; pub use type_::Type; -pub use value::{FnValue, Value}; +pub use value::{FnValue, PhiValue, Value}; struct Error<'llvm>(&'llvm mut libc::c_char); diff --git a/src/llvm/value.rs b/src/llvm/value.rs index 219c855..912c14e 100644 --- a/src/llvm/value.rs +++ b/src/llvm/value.rs @@ -1,9 +1,11 @@ +#![allow(unused)] + use llvm_sys::{ analysis::{LLVMVerifierFailureAction, LLVMVerifyFunction}, core::{ - LLVMAppendExistingBasicBlock, LLVMCountBasicBlocks, LLVMCountParams, LLVMDumpValue, - LLVMGetParam, LLVMGetReturnType, LLVMGetValueKind, LLVMGetValueName2, LLVMSetValueName2, - LLVMTypeOf, + LLVMAddIncoming, LLVMAppendExistingBasicBlock, LLVMCountBasicBlocks, LLVMCountParams, + LLVMDumpValue, LLVMGetParam, LLVMGetReturnType, LLVMGetValueKind, LLVMGetValueName2, + LLVMIsAFunction, LLVMIsAPHINode, LLVMSetValueName2, LLVMTypeOf, }, prelude::LLVMValueRef, LLVMTypeKind, LLVMValueKind, @@ -43,6 +45,18 @@ impl<'llvm> Value<'llvm> { unsafe { LLVMGetValueKind(self.value_ref()) } } + /// Check if value is `function` type. + pub(super) fn is_function(&self) -> bool { + let cast = unsafe { LLVMIsAFunction(self.value_ref()) }; + !cast.is_null() + } + + /// Check if value is `phinode` type. + pub(super) fn is_phinode(&self) -> bool { + let cast = unsafe { LLVMIsAPHINode(self.value_ref()) }; + !cast.is_null() + } + /// Dump the LLVM Value to stdout. pub fn dump(&self) { unsafe { LLVMDumpValue(self.value_ref()) }; @@ -117,9 +131,8 @@ impl<'llvm> FnValue<'llvm> { /// Panics if `value_ref` is a null pointer. pub(super) fn new(value_ref: LLVMValueRef) -> Self { let value = Value::new(value_ref); - debug_assert_eq!( - value.kind(), - LLVMValueKind::LLVMFunctionValueKind, + debug_assert!( + value.is_function(), "Expected a fn value when constructing FnValue!" ); @@ -175,3 +188,50 @@ impl<'llvm> FnValue<'llvm> { } } } + +/// Wrapper for a LLVM Value Reference specialized for contexts where phi values are needed. +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct PhiValue<'llvm>(Value<'llvm>); + +impl<'llvm> Deref for PhiValue<'llvm> { + type Target = Value<'llvm>; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'llvm> PhiValue<'llvm> { + /// Create a new PhiValue instance. + /// + /// # Panics + /// + /// Panics if `value_ref` is a null pointer. + pub(super) fn new(value_ref: LLVMValueRef) -> Self { + let value = Value::new(value_ref); + debug_assert!( + value.is_phinode(), + "Expected a phinode value when constructing PhiValue!" + ); + + PhiValue(value) + } + + /// Add an incoming value to the end of a PHI list. + pub fn add_incoming(&self, ival: Value<'llvm>, ibb: BasicBlock<'llvm>) { + debug_assert_eq!( + ival.type_of().kind(), + self.type_of().kind(), + "Type of incoming phi value must be the same as the type used to build the phi node." + ); + + unsafe { + LLVMAddIncoming( + self.value_ref(), + &mut ival.value_ref() as _, + &mut ibb.bb_ref() as _, + 1, + ); + } + } +} diff --git a/src/parser.rs b/src/parser.rs index 39e69ce..3b4fbb2 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -20,6 +20,15 @@ pub enum ExprAST { then: Box<ExprAST>, else_: Box<ExprAST>, }, + + /// ForExprAST - Expression class for for/in. + For { + var: String, + start: Box<ExprAST>, + end: Box<ExprAST>, + step: Option<Box<ExprAST>>, + body: Box<ExprAST>, + }, } /// PrototypeAST - This class represents the "prototype" for a function, @@ -196,6 +205,64 @@ where }) } + /// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression + /// + /// Implement `std::unique_ptr<ExprAST> ParseForExpr();` from the tutorial. + fn parse_for_expr(&mut self) -> ParseResult<ExprAST> { + // Consume the 'for' token. + assert_eq!(*self.cur_tok(), Token::For); + self.get_next_token(); + + let var = match self + .parse_identifier_expr() + .map_err(|_| String::from("expected identifier after 'for'"))? + { + ExprAST::Variable(var) => var, + _ => unreachable!(), + }; + + // Consume the '=' token. + if *self.cur_tok() != Token::Char('=') { + return Err("expected '=' after for".into()); + } + self.get_next_token(); + + let start = self.parse_expression()?; + + // Consume the ',' token. + if *self.cur_tok() != Token::Char(',') { + return Err("expected ',' after for start value".into()); + } + self.get_next_token(); + + let end = self.parse_expression()?; + + let step = if *self.cur_tok() == Token::Char(',') { + // Consume the ',' token. + self.get_next_token(); + + Some(self.parse_expression()?) + } else { + None + }; + + // Consume the 'in' token. + if *self.cur_tok() != Token::In { + return Err("expected 'in' after for".into()); + } + self.get_next_token(); + + let body = self.parse_expression()?; + + Ok(ExprAST::For { + var, + start: Box::new(start), + end: Box::new(end), + step: step.map(|s| Box::new(s)), + body: Box::new(body), + }) + } + /// primary /// ::= identifierexpr /// ::= numberexpr @@ -208,6 +275,7 @@ where Token::Number(_) => self.parse_num_expr(), Token::Char('(') => self.parse_paren_expr(), Token::If => self.parse_if_expr(), + Token::For => self.parse_for_expr(), _ => Err("unknown token when expecting an expression".into()), } } @@ -420,8 +488,52 @@ mod test { } #[test] + fn parse_for() { + let mut p = parser("for i = 1, 2, 3 in 4"); + + let var = String::from("i"); + let start = Box::new(ExprAST::Number(1f64)); + let end = Box::new(ExprAST::Number(2f64)); + let step = Some(Box::new(ExprAST::Number(3f64))); + let body = Box::new(ExprAST::Number(4f64)); + + assert_eq!( + p.parse_for_expr(), + Ok(ExprAST::For { + var, + start, + end, + step, + body + }) + ); + } + + #[test] + fn parse_for_no_step() { + let mut p = parser("for i = 1, 2 in 4"); + + let var = String::from("i"); + let start = Box::new(ExprAST::Number(1f64)); + let end = Box::new(ExprAST::Number(2f64)); + let step = None; + let body = Box::new(ExprAST::Number(4f64)); + + assert_eq!( + p.parse_for_expr(), + Ok(ExprAST::For { + var, + start, + end, + step, + body + }) + ); + } + + #[test] fn parse_primary() { - let mut p = parser("1337 foop \n bla(123) \n if a then b else c"); + let mut p = parser("1337 foop \n bla(123) \n if a then b else c \n for x=1,2 in 3"); assert_eq!(p.parse_primary(), Ok(ExprAST::Number(1337f64))); @@ -440,6 +552,17 @@ mod test { else_: Box::new(ExprAST::Variable("c".into())), }) ); + + assert_eq!( + p.parse_primary(), + Ok(ExprAST::For { + var: String::from("x"), + start: Box::new(ExprAST::Number(1f64)), + end: Box::new(ExprAST::Number(2f64)), + step: None, + body: Box::new(ExprAST::Number(3f64)), + }) + ); } #[test] |