aboutsummaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
authorJohannes Stoelp <johannes.stoelp@gmail.com>2021-10-04 22:51:42 +0200
committerJohannes Stoelp <johannes.stoelp@gmail.com>2021-10-04 22:51:42 +0200
commit4f6dd49df3f19204694fcea55f38efd9c5118bf2 (patch)
tree119d4fa88f6f744ecdb1eb9a833c316f5932b0c1 /src
parenta3dee93989b9fdd99b8a22a2da7f72bcd2ba50c2 (diff)
downloadllvm-kaleidoscope-rs-4f6dd49df3f19204694fcea55f38efd9c5118bf2.tar.gz
llvm-kaleidoscope-rs-4f6dd49df3f19204694fcea55f38efd9c5118bf2.zip
ch5: added if/then/else
Diffstat (limited to 'src')
-rw-r--r--src/codegen.rs72
-rw-r--r--src/lexer.rs14
-rw-r--r--src/llvm/basic_block.rs39
-rw-r--r--src/llvm/builder.rs98
-rw-r--r--src/llvm/mod.rs8
-rw-r--r--src/llvm/module.rs22
-rw-r--r--src/llvm/value.rs13
-rw-r--r--src/parser.rs77
8 files changed, 324 insertions, 19 deletions
diff --git a/src/codegen.rs b/src/codegen.rs
index 61634ad..25e8c42 100644
--- a/src/codegen.rs
+++ b/src/codegen.rs
@@ -78,6 +78,78 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> {
}
None => Err("Unknown function referenced".into()),
},
+ ExprAST::If { cond, then, else_ } => {
+ // For 'if' expressions we are building the following CFG.
+ //
+ // ; cond
+ // br
+ // |
+ // +-----+------+
+ // v v
+ // ; then ; else
+ // | |
+ // +-----+------+
+ // v
+ // ; merge
+ // phi then, else
+ // ret phi
+
+ let cond_v = {
+ // Codgen 'cond' expression.
+ let v = self.codegen_expr(cond, named_values)?;
+ // Convert condition to bool.
+ self.builder
+ .fcmpone(v, self.module.type_f64().const_f64(0f64))
+ };
+
+ // Get the function we are currently inserting into.
+ let the_function = self.builder.get_insert_block().get_parent();
+
+ // Create basic blocks for the 'then' / 'else' expressions as well as the return
+ // instruction ('merge').
+ //
+ // Append the 'then' basic block to the function, don't insert the 'else' and
+ // 'merge' basic blocks yet.
+ let then_bb = self.module.append_basic_block(the_function);
+ let else_bb = self.module.create_basic_block();
+ let merge_bb = self.module.create_basic_block();
+
+ // Create a conditional branch based on the result of the 'cond' expression.
+ self.builder.cond_br(cond_v, then_bb, else_bb);
+
+ // Move to 'then' basic block and codgen the 'then' expression.
+ self.builder.pos_at_end(then_bb);
+ let then_v = self.codegen_expr(then, named_values)?;
+ // Create unconditional branch to 'merge' block.
+ self.builder.br(merge_bb);
+ // Update reference to current basic block (in case the 'then' expression added new
+ // basic blocks).
+ let then_bb = self.builder.get_insert_block();
+
+ // Now append the 'else' basic block to the function.
+ the_function.append_basic_block(else_bb);
+ // Move to 'else' basic block and codgen the 'else' expression.
+ self.builder.pos_at_end(else_bb);
+ let else_v = self.codegen_expr(else_, named_values)?;
+ // Create unconditional branch to 'merge' block.
+ self.builder.br(merge_bb);
+ // Update reference to current basic block (in case the 'else' expression added new
+ // basic blocks).
+ let else_bb = self.builder.get_insert_block();
+
+ // Now append the 'merge' basic block to the function.
+ the_function.append_basic_block(merge_bb);
+ // Move to 'merge' basic block.
+ self.builder.pos_at_end(merge_bb);
+ // Codegen the phi node returning the appropriate value depending on the branch
+ // condition.
+ let phi = self.builder.phi(
+ self.module.type_f64(),
+ &[(then_v, then_bb), (else_v, else_bb)],
+ );
+
+ Ok(phi)
+ }
}
}
diff --git a/src/lexer.rs b/src/lexer.rs
index a25f0ab..fdab5b4 100644
--- a/src/lexer.rs
+++ b/src/lexer.rs
@@ -6,6 +6,9 @@ pub enum Token {
Identifier(String),
Number(f64),
Char(char),
+ If,
+ Then,
+ Else,
}
pub struct Lexer<I>
@@ -62,6 +65,9 @@ where
match ident.as_ref() {
"def" => return Token::Def,
"extern" => return Token::Extern,
+ "if" => return Token::If,
+ "then" => return Token::Then,
+ "else" => return Token::Else,
_ => {}
}
@@ -178,4 +184,12 @@ mod test {
assert_eq!(Token::Identifier("c".into()), lex.gettok());
assert_eq!(Token::Eof, lex.gettok());
}
+
+ #[test]
+ fn test_ite() {
+ let mut lex = Lexer::new("if then else".chars());
+ assert_eq!(Token::If, lex.gettok());
+ assert_eq!(Token::Then, lex.gettok());
+ assert_eq!(Token::Else, lex.gettok());
+ }
}
diff --git a/src/llvm/basic_block.rs b/src/llvm/basic_block.rs
new file mode 100644
index 0000000..e40c7f1
--- /dev/null
+++ b/src/llvm/basic_block.rs
@@ -0,0 +1,39 @@
+use llvm_sys::{core::LLVMGetBasicBlockParent, prelude::LLVMBasicBlockRef};
+
+use std::marker::PhantomData;
+
+use super::FnValue;
+
+/// Wrapper for a LLVM Basic Block.
+#[derive(Copy, Clone)]
+pub struct BasicBlock<'llvm>(LLVMBasicBlockRef, PhantomData<&'llvm ()>);
+
+impl<'llvm> BasicBlock<'llvm> {
+ /// Create a new BasicBlock instance.
+ ///
+ /// # Panics
+ ///
+ /// Panics if `bb_ref` is a null pointer.
+ pub(super) fn new(bb_ref: LLVMBasicBlockRef) -> BasicBlock<'llvm> {
+ assert!(!bb_ref.is_null());
+ BasicBlock(bb_ref, PhantomData)
+ }
+
+ /// Get the raw LLVM value reference.
+ #[inline]
+ pub(super) fn bb_ref(&self) -> LLVMBasicBlockRef {
+ self.0
+ }
+
+ /// Get the function to which the basic block belongs.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn get_parent(&self) -> FnValue<'llvm> {
+ let value_ref = unsafe { LLVMGetBasicBlockParent(self.bb_ref()) };
+ assert!(!value_ref.is_null());
+
+ FnValue::new(value_ref)
+ }
+}
diff --git a/src/llvm/builder.rs b/src/llvm/builder.rs
index 3d43b68..8f581f9 100644
--- a/src/llvm/builder.rs
+++ b/src/llvm/builder.rs
@@ -1,7 +1,8 @@
use llvm_sys::{
core::{
- LLVMBuildFAdd, LLVMBuildFCmp, LLVMBuildFMul, LLVMBuildFSub, LLVMBuildRet, LLVMBuildUIToFP,
- LLVMCreateBuilderInContext, LLVMDisposeBuilder, LLVMPositionBuilderAtEnd,
+ LLVMAddIncoming, LLVMBuildBr, LLVMBuildCondBr, LLVMBuildFAdd, LLVMBuildFCmp, LLVMBuildFMul,
+ LLVMBuildFSub, LLVMBuildPhi, LLVMBuildRet, LLVMBuildUIToFP, LLVMCreateBuilderInContext,
+ LLVMDisposeBuilder, LLVMGetInsertBlock, LLVMPositionBuilderAtEnd,
},
prelude::{LLVMBuilderRef, LLVMValueRef},
LLVMRealPredicate,
@@ -48,10 +49,22 @@ impl<'llvm> IRBuilder<'llvm> {
/// Position the IR Builder at the end of the given Basic Block.
pub fn pos_at_end(&self, bb: BasicBlock<'llvm>) {
unsafe {
- LLVMPositionBuilderAtEnd(self.builder, bb.0);
+ LLVMPositionBuilderAtEnd(self.builder, bb.bb_ref());
}
}
+ /// Get the BasicBlock the IRBuilder currently inputs into.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn get_insert_block(&self) -> BasicBlock<'llvm> {
+ let bb_ref = unsafe { LLVMGetInsertBlock(self.builder) };
+ assert!(!bb_ref.is_null());
+
+ BasicBlock::new(bb_ref)
+ }
+
/// Emit a [fadd](https://llvm.org/docs/LangRef.html#fadd-instruction) instruction.
///
/// # Panics
@@ -112,7 +125,7 @@ impl<'llvm> IRBuilder<'llvm> {
Value::new(value_ref)
}
- /// Emit a [fcmult](https://llvm.org/docs/LangRef.html#fcmp-instruction) instruction.
+ /// Emit a [fcmpult](https://llvm.org/docs/LangRef.html#fcmp-instruction) instruction.
///
/// # Panics
///
@@ -133,6 +146,27 @@ impl<'llvm> IRBuilder<'llvm> {
Value::new(value_ref)
}
+ /// Emit a [fcmpone](https://llvm.org/docs/LangRef.html#fcmp-instruction) instruction.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn fcmpone(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> {
+ debug_assert!(lhs.is_f64(), "fcmplt: Expected f64 as lhs operand!");
+ debug_assert!(rhs.is_f64(), "fcmplt: Expected f64 as rhs operand!");
+
+ let value_ref = unsafe {
+ LLVMBuildFCmp(
+ self.builder,
+ LLVMRealPredicate::LLVMRealONE,
+ lhs.value_ref(),
+ rhs.value_ref(),
+ b"fcmpone\0".as_ptr().cast(),
+ )
+ };
+ Value::new(value_ref)
+ }
+
/// Emit a [uitofp](https://llvm.org/docs/LangRef.html#uitofp-to-instruction) instruction.
///
/// # Panics
@@ -180,6 +214,62 @@ impl<'llvm> IRBuilder<'llvm> {
let ret = unsafe { LLVMBuildRet(self.builder, ret.value_ref()) };
assert!(!ret.is_null());
}
+
+ /// Emit an unconditional [br](https://llvm.org/docs/LangRef.html#br-instruction) instruction.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn br(&self, dest: BasicBlock<'llvm>) {
+ let br_ref = unsafe { LLVMBuildBr(self.builder, dest.bb_ref()) };
+ assert!(!br_ref.is_null());
+ }
+
+ /// Emit a conditional [br](https://llvm.org/docs/LangRef.html#br-instruction) instruction.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn cond_br(&self, cond: Value<'llvm>, then: BasicBlock<'llvm>, else_: BasicBlock<'llvm>) {
+ let br_ref = unsafe {
+ LLVMBuildCondBr(
+ self.builder,
+ cond.value_ref(),
+ then.bb_ref(),
+ else_.bb_ref(),
+ )
+ };
+ assert!(!br_ref.is_null());
+ }
+
+ /// Emit a [phi](https://llvm.org/docs/LangRef.html#phi-instruction) instruction.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn phi(
+ &self,
+ phi_type: Type<'llvm>,
+ incoming: &[(Value<'llvm>, BasicBlock<'llvm>)],
+ ) -> Value<'llvm> {
+ let phi_ref =
+ unsafe { LLVMBuildPhi(self.builder, phi_type.type_ref(), b"phi\0".as_ptr().cast()) };
+ assert!(!phi_ref.is_null());
+
+ for (val, bb) in incoming {
+ debug_assert_eq!(
+ val.type_of().kind(),
+ phi_type.kind(),
+ "Type of incoming phi value must be the same as the type used to build the phi node."
+ );
+
+ unsafe {
+ LLVMAddIncoming(phi_ref, &mut val.value_ref() as _, &mut bb.bb_ref() as _, 1);
+ }
+ }
+
+ Value::new(phi_ref)
+ }
}
impl Drop for IRBuilder<'_> {
diff --git a/src/llvm/mod.rs b/src/llvm/mod.rs
index 16e6bfd..c9f17b6 100644
--- a/src/llvm/mod.rs
+++ b/src/llvm/mod.rs
@@ -11,7 +11,6 @@
use llvm_sys::{
core::LLVMShutdown,
error::{LLVMDisposeErrorMessage, LLVMErrorRef, LLVMGetErrorMessage},
- prelude::LLVMBasicBlockRef,
target::{
LLVM_InitializeNativeAsmParser, LLVM_InitializeNativeAsmPrinter,
LLVM_InitializeNativeTarget,
@@ -19,8 +18,8 @@ use llvm_sys::{
};
use std::ffi::CStr;
-use std::marker::PhantomData;
+mod basic_block;
mod builder;
mod lljit;
mod module;
@@ -28,6 +27,7 @@ mod pass_manager;
mod type_;
mod value;
+pub use basic_block::BasicBlock;
pub use builder::IRBuilder;
pub use lljit::{LLJit, ResourceTracker};
pub use module::Module;
@@ -35,10 +35,6 @@ pub use pass_manager::FunctionPassManager;
pub use type_::Type;
pub use value::{FnValue, Value};
-/// Wrapper for a LLVM Basic Block.
-#[derive(Copy, Clone)]
-pub struct BasicBlock<'llvm>(LLVMBasicBlockRef, PhantomData<&'llvm ()>);
-
struct Error<'llvm>(&'llvm mut libc::c_char);
impl<'llvm> Error<'llvm> {
diff --git a/src/llvm/module.rs b/src/llvm/module.rs
index d737b8e..21d85f5 100644
--- a/src/llvm/module.rs
+++ b/src/llvm/module.rs
@@ -1,7 +1,8 @@
use llvm_sys::{
core::{
- LLVMAddFunction, LLVMAppendBasicBlockInContext, LLVMDisposeModule, LLVMDoubleTypeInContext,
- LLVMDumpModule, LLVMGetNamedFunction, LLVMModuleCreateWithNameInContext,
+ LLVMAddFunction, LLVMAppendBasicBlockInContext, LLVMCreateBasicBlockInContext,
+ LLVMDisposeModule, LLVMDoubleTypeInContext, LLVMDumpModule, LLVMGetNamedFunction,
+ LLVMModuleCreateWithNameInContext,
},
orc2::{
LLVMOrcCreateNewThreadSafeContext, LLVMOrcCreateNewThreadSafeModule,
@@ -13,7 +14,6 @@ use llvm_sys::{
};
use std::convert::TryFrom;
-use std::marker::PhantomData;
use super::{BasicBlock, FnValue, Type};
use crate::SmallCStr;
@@ -176,7 +176,21 @@ impl<'llvm> Module {
};
assert!(!block.is_null());
- BasicBlock(block, PhantomData)
+ BasicBlock::new(block)
+ }
+
+ /// Create a free-standing Basic Block without adding it to a function.
+ /// This can be added to a function at a later point in time with
+ /// [`FnValue::append_basic_block`].
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn create_basic_block(&self) -> BasicBlock<'llvm> {
+ let block = unsafe { LLVMCreateBasicBlockInContext(self.ctx, b"block\0".as_ptr().cast()) };
+ assert!(!block.is_null());
+
+ BasicBlock::new(block)
}
}
diff --git a/src/llvm/value.rs b/src/llvm/value.rs
index 9b79c69..219c855 100644
--- a/src/llvm/value.rs
+++ b/src/llvm/value.rs
@@ -1,8 +1,9 @@
use llvm_sys::{
analysis::{LLVMVerifierFailureAction, LLVMVerifyFunction},
core::{
- LLVMCountBasicBlocks, LLVMCountParams, LLVMDumpValue, LLVMGetParam, LLVMGetReturnType,
- LLVMGetValueKind, LLVMGetValueName2, LLVMSetValueName2, LLVMTypeOf,
+ LLVMAppendExistingBasicBlock, LLVMCountBasicBlocks, LLVMCountParams, LLVMDumpValue,
+ LLVMGetParam, LLVMGetReturnType, LLVMGetValueKind, LLVMGetValueName2, LLVMSetValueName2,
+ LLVMTypeOf,
},
prelude::LLVMValueRef,
LLVMTypeKind, LLVMValueKind,
@@ -12,6 +13,7 @@ use std::ffi::CStr;
use std::marker::PhantomData;
use std::ops::Deref;
+use super::BasicBlock;
use super::Type;
/// Wrapper for a LLVM Value Reference.
@@ -156,6 +158,13 @@ impl<'llvm> FnValue<'llvm> {
unsafe { LLVMCountBasicBlocks(self.value_ref()) as usize }
}
+ /// Append a Basic Block to the end of the function value.
+ pub fn append_basic_block(&self, bb: BasicBlock<'llvm>) {
+ unsafe {
+ LLVMAppendExistingBasicBlock(self.value_ref(), bb.bb_ref());
+ }
+ }
+
/// Verify that the given function is valid.
pub fn verify(&self) -> bool {
unsafe {
diff --git a/src/parser.rs b/src/parser.rs
index b3a26cb..39e69ce 100644
--- a/src/parser.rs
+++ b/src/parser.rs
@@ -13,6 +13,13 @@ pub enum ExprAST {
/// Call - Expression class for function calls.
Call(String, Vec<ExprAST>),
+
+ /// If - Expression class for if/then/else.
+ If {
+ cond: Box<ExprAST>,
+ then: Box<ExprAST>,
+ else_: Box<ExprAST>,
+ },
}
/// PrototypeAST - This class represents the "prototype" for a function,
@@ -137,8 +144,6 @@ where
args.push(arg);
if *self.cur_tok() == Token::Char(')') {
- // Eat ')' token.
- self.get_next_token();
break;
}
@@ -150,10 +155,47 @@ where
}
}
+ assert_eq!(*self.cur_tok(), Token::Char(')'));
+ // Eat ')' token.
+ self.get_next_token();
+
Ok(ExprAST::Call(id_name, args))
}
}
+ /// ifexpr ::= 'if' expression 'then' expression 'else' expression
+ ///
+ /// Implement `std::unique_ptr<ExprAST> ParseIfExpr();` from the tutorial.
+ fn parse_if_expr(&mut self) -> ParseResult<ExprAST> {
+ // Consume 'if' token.
+ assert_eq!(*self.cur_tok(), Token::If);
+ self.get_next_token();
+
+ let cond = self.parse_expression()?;
+
+ if *dbg!(self.cur_tok()) != Token::Then {
+ return Err("Expected 'then'".into());
+ }
+ // Consume 'then' token.
+ self.get_next_token();
+
+ let then = self.parse_expression()?;
+
+ if *self.cur_tok() != Token::Else {
+ return Err("Expected 'else'".into());
+ }
+ // Consume 'else' token.
+ self.get_next_token();
+
+ let else_ = self.parse_expression()?;
+
+ Ok(ExprAST::If {
+ cond: Box::new(cond),
+ then: Box::new(then),
+ else_: Box::new(else_),
+ })
+ }
+
/// primary
/// ::= identifierexpr
/// ::= numberexpr
@@ -165,6 +207,7 @@ where
Token::Identifier(_) => self.parse_identifier_expr(),
Token::Number(_) => self.parse_num_expr(),
Token::Char('(') => self.parse_paren_expr(),
+ Token::If => self.parse_if_expr(),
_ => Err("unknown token when expecting an expression".into()),
}
}
@@ -358,8 +401,27 @@ mod test {
}
#[test]
+ fn parse_if() {
+ let mut p = parser("if 1 then 2 else 3");
+
+ let cond = Box::new(ExprAST::Number(1f64));
+ let then = Box::new(ExprAST::Number(2f64));
+ let else_ = Box::new(ExprAST::Number(3f64));
+
+ assert_eq!(p.parse_if_expr(), Ok(ExprAST::If { cond, then, else_ }));
+
+ let mut p = parser("if foo() then bar(2) else baz(3)");
+
+ let cond = Box::new(ExprAST::Call("foo".into(), vec![]));
+ let then = Box::new(ExprAST::Call("bar".into(), vec![ExprAST::Number(2f64)]));
+ let else_ = Box::new(ExprAST::Call("baz".into(), vec![ExprAST::Number(3f64)]));
+
+ assert_eq!(p.parse_if_expr(), Ok(ExprAST::If { cond, then, else_ }));
+ }
+
+ #[test]
fn parse_primary() {
- let mut p = parser("1337 foop \n bla(123)");
+ let mut p = parser("1337 foop \n bla(123) \n if a then b else c");
assert_eq!(p.parse_primary(), Ok(ExprAST::Number(1337f64)));
@@ -369,6 +431,15 @@ mod test {
p.parse_primary(),
Ok(ExprAST::Call("bla".into(), vec![ExprAST::Number(123f64)]))
);
+
+ assert_eq!(
+ p.parse_primary(),
+ Ok(ExprAST::If {
+ cond: Box::new(ExprAST::Variable("a".into())),
+ then: Box::new(ExprAST::Variable("b".into())),
+ else_: Box::new(ExprAST::Variable("c".into())),
+ })
+ );
}
#[test]