diff options
author | Johannes Stoelp <johannes.stoelp@gmail.com> | 2021-10-04 22:51:42 +0200 |
---|---|---|
committer | Johannes Stoelp <johannes.stoelp@gmail.com> | 2021-10-04 22:51:42 +0200 |
commit | 4f6dd49df3f19204694fcea55f38efd9c5118bf2 (patch) | |
tree | 119d4fa88f6f744ecdb1eb9a833c316f5932b0c1 | |
parent | a3dee93989b9fdd99b8a22a2da7f72bcd2ba50c2 (diff) | |
download | llvm-kaleidoscope-rs-4f6dd49df3f19204694fcea55f38efd9c5118bf2.tar.gz llvm-kaleidoscope-rs-4f6dd49df3f19204694fcea55f38efd9c5118bf2.zip |
ch5: added if/then/else
-rw-r--r-- | src/codegen.rs | 72 | ||||
-rw-r--r-- | src/lexer.rs | 14 | ||||
-rw-r--r-- | src/llvm/basic_block.rs | 39 | ||||
-rw-r--r-- | src/llvm/builder.rs | 98 | ||||
-rw-r--r-- | src/llvm/mod.rs | 8 | ||||
-rw-r--r-- | src/llvm/module.rs | 22 | ||||
-rw-r--r-- | src/llvm/value.rs | 13 | ||||
-rw-r--r-- | src/parser.rs | 77 |
8 files changed, 324 insertions, 19 deletions
diff --git a/src/codegen.rs b/src/codegen.rs index 61634ad..25e8c42 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -78,6 +78,78 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> { } None => Err("Unknown function referenced".into()), }, + ExprAST::If { cond, then, else_ } => { + // For 'if' expressions we are building the following CFG. + // + // ; cond + // br + // | + // +-----+------+ + // v v + // ; then ; else + // | | + // +-----+------+ + // v + // ; merge + // phi then, else + // ret phi + + let cond_v = { + // Codgen 'cond' expression. + let v = self.codegen_expr(cond, named_values)?; + // Convert condition to bool. + self.builder + .fcmpone(v, self.module.type_f64().const_f64(0f64)) + }; + + // Get the function we are currently inserting into. + let the_function = self.builder.get_insert_block().get_parent(); + + // Create basic blocks for the 'then' / 'else' expressions as well as the return + // instruction ('merge'). + // + // Append the 'then' basic block to the function, don't insert the 'else' and + // 'merge' basic blocks yet. + let then_bb = self.module.append_basic_block(the_function); + let else_bb = self.module.create_basic_block(); + let merge_bb = self.module.create_basic_block(); + + // Create a conditional branch based on the result of the 'cond' expression. + self.builder.cond_br(cond_v, then_bb, else_bb); + + // Move to 'then' basic block and codgen the 'then' expression. + self.builder.pos_at_end(then_bb); + let then_v = self.codegen_expr(then, named_values)?; + // Create unconditional branch to 'merge' block. + self.builder.br(merge_bb); + // Update reference to current basic block (in case the 'then' expression added new + // basic blocks). + let then_bb = self.builder.get_insert_block(); + + // Now append the 'else' basic block to the function. + the_function.append_basic_block(else_bb); + // Move to 'else' basic block and codgen the 'else' expression. + self.builder.pos_at_end(else_bb); + let else_v = self.codegen_expr(else_, named_values)?; + // Create unconditional branch to 'merge' block. + self.builder.br(merge_bb); + // Update reference to current basic block (in case the 'else' expression added new + // basic blocks). + let else_bb = self.builder.get_insert_block(); + + // Now append the 'merge' basic block to the function. + the_function.append_basic_block(merge_bb); + // Move to 'merge' basic block. + self.builder.pos_at_end(merge_bb); + // Codegen the phi node returning the appropriate value depending on the branch + // condition. + let phi = self.builder.phi( + self.module.type_f64(), + &[(then_v, then_bb), (else_v, else_bb)], + ); + + Ok(phi) + } } } diff --git a/src/lexer.rs b/src/lexer.rs index a25f0ab..fdab5b4 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -6,6 +6,9 @@ pub enum Token { Identifier(String), Number(f64), Char(char), + If, + Then, + Else, } pub struct Lexer<I> @@ -62,6 +65,9 @@ where match ident.as_ref() { "def" => return Token::Def, "extern" => return Token::Extern, + "if" => return Token::If, + "then" => return Token::Then, + "else" => return Token::Else, _ => {} } @@ -178,4 +184,12 @@ mod test { assert_eq!(Token::Identifier("c".into()), lex.gettok()); assert_eq!(Token::Eof, lex.gettok()); } + + #[test] + fn test_ite() { + let mut lex = Lexer::new("if then else".chars()); + assert_eq!(Token::If, lex.gettok()); + assert_eq!(Token::Then, lex.gettok()); + assert_eq!(Token::Else, lex.gettok()); + } } diff --git a/src/llvm/basic_block.rs b/src/llvm/basic_block.rs new file mode 100644 index 0000000..e40c7f1 --- /dev/null +++ b/src/llvm/basic_block.rs @@ -0,0 +1,39 @@ +use llvm_sys::{core::LLVMGetBasicBlockParent, prelude::LLVMBasicBlockRef}; + +use std::marker::PhantomData; + +use super::FnValue; + +/// Wrapper for a LLVM Basic Block. +#[derive(Copy, Clone)] +pub struct BasicBlock<'llvm>(LLVMBasicBlockRef, PhantomData<&'llvm ()>); + +impl<'llvm> BasicBlock<'llvm> { + /// Create a new BasicBlock instance. + /// + /// # Panics + /// + /// Panics if `bb_ref` is a null pointer. + pub(super) fn new(bb_ref: LLVMBasicBlockRef) -> BasicBlock<'llvm> { + assert!(!bb_ref.is_null()); + BasicBlock(bb_ref, PhantomData) + } + + /// Get the raw LLVM value reference. + #[inline] + pub(super) fn bb_ref(&self) -> LLVMBasicBlockRef { + self.0 + } + + /// Get the function to which the basic block belongs. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn get_parent(&self) -> FnValue<'llvm> { + let value_ref = unsafe { LLVMGetBasicBlockParent(self.bb_ref()) }; + assert!(!value_ref.is_null()); + + FnValue::new(value_ref) + } +} diff --git a/src/llvm/builder.rs b/src/llvm/builder.rs index 3d43b68..8f581f9 100644 --- a/src/llvm/builder.rs +++ b/src/llvm/builder.rs @@ -1,7 +1,8 @@ use llvm_sys::{ core::{ - LLVMBuildFAdd, LLVMBuildFCmp, LLVMBuildFMul, LLVMBuildFSub, LLVMBuildRet, LLVMBuildUIToFP, - LLVMCreateBuilderInContext, LLVMDisposeBuilder, LLVMPositionBuilderAtEnd, + LLVMAddIncoming, LLVMBuildBr, LLVMBuildCondBr, LLVMBuildFAdd, LLVMBuildFCmp, LLVMBuildFMul, + LLVMBuildFSub, LLVMBuildPhi, LLVMBuildRet, LLVMBuildUIToFP, LLVMCreateBuilderInContext, + LLVMDisposeBuilder, LLVMGetInsertBlock, LLVMPositionBuilderAtEnd, }, prelude::{LLVMBuilderRef, LLVMValueRef}, LLVMRealPredicate, @@ -48,10 +49,22 @@ impl<'llvm> IRBuilder<'llvm> { /// Position the IR Builder at the end of the given Basic Block. pub fn pos_at_end(&self, bb: BasicBlock<'llvm>) { unsafe { - LLVMPositionBuilderAtEnd(self.builder, bb.0); + LLVMPositionBuilderAtEnd(self.builder, bb.bb_ref()); } } + /// Get the BasicBlock the IRBuilder currently inputs into. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn get_insert_block(&self) -> BasicBlock<'llvm> { + let bb_ref = unsafe { LLVMGetInsertBlock(self.builder) }; + assert!(!bb_ref.is_null()); + + BasicBlock::new(bb_ref) + } + /// Emit a [fadd](https://llvm.org/docs/LangRef.html#fadd-instruction) instruction. /// /// # Panics @@ -112,7 +125,7 @@ impl<'llvm> IRBuilder<'llvm> { Value::new(value_ref) } - /// Emit a [fcmult](https://llvm.org/docs/LangRef.html#fcmp-instruction) instruction. + /// Emit a [fcmpult](https://llvm.org/docs/LangRef.html#fcmp-instruction) instruction. /// /// # Panics /// @@ -133,6 +146,27 @@ impl<'llvm> IRBuilder<'llvm> { Value::new(value_ref) } + /// Emit a [fcmpone](https://llvm.org/docs/LangRef.html#fcmp-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn fcmpone(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { + debug_assert!(lhs.is_f64(), "fcmplt: Expected f64 as lhs operand!"); + debug_assert!(rhs.is_f64(), "fcmplt: Expected f64 as rhs operand!"); + + let value_ref = unsafe { + LLVMBuildFCmp( + self.builder, + LLVMRealPredicate::LLVMRealONE, + lhs.value_ref(), + rhs.value_ref(), + b"fcmpone\0".as_ptr().cast(), + ) + }; + Value::new(value_ref) + } + /// Emit a [uitofp](https://llvm.org/docs/LangRef.html#uitofp-to-instruction) instruction. /// /// # Panics @@ -180,6 +214,62 @@ impl<'llvm> IRBuilder<'llvm> { let ret = unsafe { LLVMBuildRet(self.builder, ret.value_ref()) }; assert!(!ret.is_null()); } + + /// Emit an unconditional [br](https://llvm.org/docs/LangRef.html#br-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn br(&self, dest: BasicBlock<'llvm>) { + let br_ref = unsafe { LLVMBuildBr(self.builder, dest.bb_ref()) }; + assert!(!br_ref.is_null()); + } + + /// Emit a conditional [br](https://llvm.org/docs/LangRef.html#br-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn cond_br(&self, cond: Value<'llvm>, then: BasicBlock<'llvm>, else_: BasicBlock<'llvm>) { + let br_ref = unsafe { + LLVMBuildCondBr( + self.builder, + cond.value_ref(), + then.bb_ref(), + else_.bb_ref(), + ) + }; + assert!(!br_ref.is_null()); + } + + /// Emit a [phi](https://llvm.org/docs/LangRef.html#phi-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn phi( + &self, + phi_type: Type<'llvm>, + incoming: &[(Value<'llvm>, BasicBlock<'llvm>)], + ) -> Value<'llvm> { + let phi_ref = + unsafe { LLVMBuildPhi(self.builder, phi_type.type_ref(), b"phi\0".as_ptr().cast()) }; + assert!(!phi_ref.is_null()); + + for (val, bb) in incoming { + debug_assert_eq!( + val.type_of().kind(), + phi_type.kind(), + "Type of incoming phi value must be the same as the type used to build the phi node." + ); + + unsafe { + LLVMAddIncoming(phi_ref, &mut val.value_ref() as _, &mut bb.bb_ref() as _, 1); + } + } + + Value::new(phi_ref) + } } impl Drop for IRBuilder<'_> { diff --git a/src/llvm/mod.rs b/src/llvm/mod.rs index 16e6bfd..c9f17b6 100644 --- a/src/llvm/mod.rs +++ b/src/llvm/mod.rs @@ -11,7 +11,6 @@ use llvm_sys::{ core::LLVMShutdown, error::{LLVMDisposeErrorMessage, LLVMErrorRef, LLVMGetErrorMessage}, - prelude::LLVMBasicBlockRef, target::{ LLVM_InitializeNativeAsmParser, LLVM_InitializeNativeAsmPrinter, LLVM_InitializeNativeTarget, @@ -19,8 +18,8 @@ use llvm_sys::{ }; use std::ffi::CStr; -use std::marker::PhantomData; +mod basic_block; mod builder; mod lljit; mod module; @@ -28,6 +27,7 @@ mod pass_manager; mod type_; mod value; +pub use basic_block::BasicBlock; pub use builder::IRBuilder; pub use lljit::{LLJit, ResourceTracker}; pub use module::Module; @@ -35,10 +35,6 @@ pub use pass_manager::FunctionPassManager; pub use type_::Type; pub use value::{FnValue, Value}; -/// Wrapper for a LLVM Basic Block. -#[derive(Copy, Clone)] -pub struct BasicBlock<'llvm>(LLVMBasicBlockRef, PhantomData<&'llvm ()>); - struct Error<'llvm>(&'llvm mut libc::c_char); impl<'llvm> Error<'llvm> { diff --git a/src/llvm/module.rs b/src/llvm/module.rs index d737b8e..21d85f5 100644 --- a/src/llvm/module.rs +++ b/src/llvm/module.rs @@ -1,7 +1,8 @@ use llvm_sys::{ core::{ - LLVMAddFunction, LLVMAppendBasicBlockInContext, LLVMDisposeModule, LLVMDoubleTypeInContext, - LLVMDumpModule, LLVMGetNamedFunction, LLVMModuleCreateWithNameInContext, + LLVMAddFunction, LLVMAppendBasicBlockInContext, LLVMCreateBasicBlockInContext, + LLVMDisposeModule, LLVMDoubleTypeInContext, LLVMDumpModule, LLVMGetNamedFunction, + LLVMModuleCreateWithNameInContext, }, orc2::{ LLVMOrcCreateNewThreadSafeContext, LLVMOrcCreateNewThreadSafeModule, @@ -13,7 +14,6 @@ use llvm_sys::{ }; use std::convert::TryFrom; -use std::marker::PhantomData; use super::{BasicBlock, FnValue, Type}; use crate::SmallCStr; @@ -176,7 +176,21 @@ impl<'llvm> Module { }; assert!(!block.is_null()); - BasicBlock(block, PhantomData) + BasicBlock::new(block) + } + + /// Create a free-standing Basic Block without adding it to a function. + /// This can be added to a function at a later point in time with + /// [`FnValue::append_basic_block`]. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn create_basic_block(&self) -> BasicBlock<'llvm> { + let block = unsafe { LLVMCreateBasicBlockInContext(self.ctx, b"block\0".as_ptr().cast()) }; + assert!(!block.is_null()); + + BasicBlock::new(block) } } diff --git a/src/llvm/value.rs b/src/llvm/value.rs index 9b79c69..219c855 100644 --- a/src/llvm/value.rs +++ b/src/llvm/value.rs @@ -1,8 +1,9 @@ use llvm_sys::{ analysis::{LLVMVerifierFailureAction, LLVMVerifyFunction}, core::{ - LLVMCountBasicBlocks, LLVMCountParams, LLVMDumpValue, LLVMGetParam, LLVMGetReturnType, - LLVMGetValueKind, LLVMGetValueName2, LLVMSetValueName2, LLVMTypeOf, + LLVMAppendExistingBasicBlock, LLVMCountBasicBlocks, LLVMCountParams, LLVMDumpValue, + LLVMGetParam, LLVMGetReturnType, LLVMGetValueKind, LLVMGetValueName2, LLVMSetValueName2, + LLVMTypeOf, }, prelude::LLVMValueRef, LLVMTypeKind, LLVMValueKind, @@ -12,6 +13,7 @@ use std::ffi::CStr; use std::marker::PhantomData; use std::ops::Deref; +use super::BasicBlock; use super::Type; /// Wrapper for a LLVM Value Reference. @@ -156,6 +158,13 @@ impl<'llvm> FnValue<'llvm> { unsafe { LLVMCountBasicBlocks(self.value_ref()) as usize } } + /// Append a Basic Block to the end of the function value. + pub fn append_basic_block(&self, bb: BasicBlock<'llvm>) { + unsafe { + LLVMAppendExistingBasicBlock(self.value_ref(), bb.bb_ref()); + } + } + /// Verify that the given function is valid. pub fn verify(&self) -> bool { unsafe { diff --git a/src/parser.rs b/src/parser.rs index b3a26cb..39e69ce 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -13,6 +13,13 @@ pub enum ExprAST { /// Call - Expression class for function calls. Call(String, Vec<ExprAST>), + + /// If - Expression class for if/then/else. + If { + cond: Box<ExprAST>, + then: Box<ExprAST>, + else_: Box<ExprAST>, + }, } /// PrototypeAST - This class represents the "prototype" for a function, @@ -137,8 +144,6 @@ where args.push(arg); if *self.cur_tok() == Token::Char(')') { - // Eat ')' token. - self.get_next_token(); break; } @@ -150,10 +155,47 @@ where } } + assert_eq!(*self.cur_tok(), Token::Char(')')); + // Eat ')' token. + self.get_next_token(); + Ok(ExprAST::Call(id_name, args)) } } + /// ifexpr ::= 'if' expression 'then' expression 'else' expression + /// + /// Implement `std::unique_ptr<ExprAST> ParseIfExpr();` from the tutorial. + fn parse_if_expr(&mut self) -> ParseResult<ExprAST> { + // Consume 'if' token. + assert_eq!(*self.cur_tok(), Token::If); + self.get_next_token(); + + let cond = self.parse_expression()?; + + if *dbg!(self.cur_tok()) != Token::Then { + return Err("Expected 'then'".into()); + } + // Consume 'then' token. + self.get_next_token(); + + let then = self.parse_expression()?; + + if *self.cur_tok() != Token::Else { + return Err("Expected 'else'".into()); + } + // Consume 'else' token. + self.get_next_token(); + + let else_ = self.parse_expression()?; + + Ok(ExprAST::If { + cond: Box::new(cond), + then: Box::new(then), + else_: Box::new(else_), + }) + } + /// primary /// ::= identifierexpr /// ::= numberexpr @@ -165,6 +207,7 @@ where Token::Identifier(_) => self.parse_identifier_expr(), Token::Number(_) => self.parse_num_expr(), Token::Char('(') => self.parse_paren_expr(), + Token::If => self.parse_if_expr(), _ => Err("unknown token when expecting an expression".into()), } } @@ -358,8 +401,27 @@ mod test { } #[test] + fn parse_if() { + let mut p = parser("if 1 then 2 else 3"); + + let cond = Box::new(ExprAST::Number(1f64)); + let then = Box::new(ExprAST::Number(2f64)); + let else_ = Box::new(ExprAST::Number(3f64)); + + assert_eq!(p.parse_if_expr(), Ok(ExprAST::If { cond, then, else_ })); + + let mut p = parser("if foo() then bar(2) else baz(3)"); + + let cond = Box::new(ExprAST::Call("foo".into(), vec![])); + let then = Box::new(ExprAST::Call("bar".into(), vec![ExprAST::Number(2f64)])); + let else_ = Box::new(ExprAST::Call("baz".into(), vec![ExprAST::Number(3f64)])); + + assert_eq!(p.parse_if_expr(), Ok(ExprAST::If { cond, then, else_ })); + } + + #[test] fn parse_primary() { - let mut p = parser("1337 foop \n bla(123)"); + let mut p = parser("1337 foop \n bla(123) \n if a then b else c"); assert_eq!(p.parse_primary(), Ok(ExprAST::Number(1337f64))); @@ -369,6 +431,15 @@ mod test { p.parse_primary(), Ok(ExprAST::Call("bla".into(), vec![ExprAST::Number(123f64)])) ); + + assert_eq!( + p.parse_primary(), + Ok(ExprAST::If { + cond: Box::new(ExprAST::Variable("a".into())), + then: Box::new(ExprAST::Variable("b".into())), + else_: Box::new(ExprAST::Variable("c".into())), + }) + ); } #[test] |