aboutsummaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
authorJohannes Stoelp <johannes.stoelp@gmail.com>2022-03-27 22:20:06 +0200
committerJohannes Stoelp <johannes.stoelp@gmail.com>2022-03-27 22:20:06 +0200
commit21ea78e57fa480d472d3660881e91813f7b18820 (patch)
tree476b2721ba44d41d9b2faf003701fdc647c9259e /src
parent4f6dd49df3f19204694fcea55f38efd9c5118bf2 (diff)
downloadllvm-kaleidoscope-rs-21ea78e57fa480d472d3660881e91813f7b18820.tar.gz
llvm-kaleidoscope-rs-21ea78e57fa480d472d3660881e91813f7b18820.zip
ch5: added for loop
Diffstat (limited to 'src')
-rw-r--r--src/codegen.rs98
-rw-r--r--src/lexer.rs11
-rw-r--r--src/llvm/builder.rs16
-rw-r--r--src/llvm/mod.rs2
-rw-r--r--src/llvm/value.rs72
-rw-r--r--src/parser.rs125
6 files changed, 303 insertions, 21 deletions
diff --git a/src/codegen.rs b/src/codegen.rs
index 25e8c42..b2d15e7 100644
--- a/src/codegen.rs
+++ b/src/codegen.rs
@@ -38,7 +38,7 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> {
fn codegen_expr(
&self,
expr: &ExprAST,
- named_values: &mut HashMap<&'llvm str, Value<'llvm>>,
+ named_values: &mut HashMap<String, Value<'llvm>>,
) -> CodegenResult<Value<'llvm>> {
match expr {
ExprAST::Number(num) => Ok(self.module.type_f64().const_f64(*num)),
@@ -97,7 +97,7 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> {
let cond_v = {
// Codgen 'cond' expression.
let v = self.codegen_expr(cond, named_values)?;
- // Convert condition to bool.
+ // Compare 'v' against '0' as 'one = ordered not equal'.
self.builder
.fcmpone(v, self.module.type_f64().const_f64(0f64))
};
@@ -148,7 +148,95 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> {
&[(then_v, then_bb), (else_v, else_bb)],
);
- Ok(phi)
+ Ok(*phi)
+ }
+ ExprAST::For {
+ var,
+ start,
+ end,
+ step,
+ body,
+ } => {
+ // For 'for' expression we build the following structure.
+ //
+ // entry:
+ // init = start expression
+ // br loop
+ // loop:
+ // i = phi [%init, %entry], [%new_i, %loop]
+ // ; loop body ...
+ // new_i = increment %i by step expression
+ // ; check end condition and branch
+ // end:
+
+ // Compute initial value for the loop variable.
+ let start_val = self.codegen_expr(start, named_values)?;
+
+ let the_function = self.builder.get_insert_block().get_parent();
+ // Get current basic block (used in the loop variable phi node).
+ let entry_bb = self.builder.get_insert_block();
+ // Add new basic block to emit loop body.
+ let loop_bb = self.module.append_basic_block(the_function);
+
+ self.builder.br(loop_bb);
+ self.builder.pos_at_end(loop_bb);
+
+ // Build phi not to pick loop variable in case we come from the 'entry' block.
+ // Which is the case when we enter the loop for the first time.
+ // We will add another incoming value once we computed the updated loop variable
+ // below.
+ let variable = self
+ .builder
+ .phi(self.module.type_f64(), &[(start_val, entry_bb)]);
+
+ // Insert the loop variable into the named values map that it can be referenced
+ // from the body as well as the end condition.
+ // In case the loop variable shadows an existing variable remember the shared one.
+ let old_val = named_values.insert(var.into(), *variable);
+
+ // Generate the loop body.
+ self.codegen_expr(body, named_values)?;
+
+ // Generate step value expression if available else use '1'.
+ let step_val = if let Some(step) = step {
+ self.codegen_expr(step, named_values)?
+ } else {
+ self.module.type_f64().const_f64(1f64)
+ };
+
+ // Increment loop variable.
+ let next_var = self.builder.fadd(*variable, step_val);
+
+ // Generate the loop end condition.
+ let end_cond = self.codegen_expr(end, named_values)?;
+ let end_cond = self
+ .builder
+ .fcmpone(end_cond, self.module.type_f64().const_f64(0f64));
+
+ // Get current basic block.
+ let loop_end_bb = self.builder.get_insert_block();
+ // Add new basic block following the loop.
+ let after_bb = self.module.append_basic_block(the_function);
+
+ // Register additional incoming value for the loop variable. This will choose the
+ // updated loop variable if we are iterating in the loop.
+ variable.add_incoming(next_var, loop_end_bb);
+
+ // Branch depending on the loop end condition.
+ self.builder.cond_br(end_cond, loop_bb, after_bb);
+
+ self.builder.pos_at_end(after_bb);
+
+ // Restore the shadowed variable if there was one.
+ if let Some(old_val) = old_val {
+ // We inserted 'var' above so it must exist.
+ *named_values.get_mut(var).unwrap() = old_val;
+ } else {
+ named_values.remove(var);
+ }
+
+ // Loops just always return 0.
+ Ok(self.module.type_f64().const_f64(0f64))
}
}
}
@@ -176,7 +264,7 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> {
fn codegen_function(
&mut self,
FunctionAST(proto, body): &FunctionAST,
- named_values: &mut HashMap<&'llvm str, Value<'llvm>>,
+ named_values: &mut HashMap<String, Value<'llvm>>,
) -> CodegenResult<FnValue<'llvm>> {
// Insert the function prototype into the `fn_protos` map to keep track for re-generating
// declarations in other modules.
@@ -199,7 +287,7 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> {
// Update the map with the current functions args.
for idx in 0..the_function.args() {
let arg = the_function.arg(idx);
- named_values.insert(arg.get_name(), arg);
+ named_values.insert(arg.get_name().into(), arg);
}
// Codegen function body.
diff --git a/src/lexer.rs b/src/lexer.rs
index fdab5b4..365b8bf 100644
--- a/src/lexer.rs
+++ b/src/lexer.rs
@@ -9,6 +9,8 @@ pub enum Token {
If,
Then,
Else,
+ For,
+ In,
}
pub struct Lexer<I>
@@ -68,6 +70,8 @@ where
"if" => return Token::If,
"then" => return Token::Then,
"else" => return Token::Else,
+ "for" => return Token::For,
+ "in" => return Token::In,
_ => {}
}
@@ -192,4 +196,11 @@ mod test {
assert_eq!(Token::Then, lex.gettok());
assert_eq!(Token::Else, lex.gettok());
}
+
+ #[test]
+ fn test_for() {
+ let mut lex = Lexer::new("for in".chars());
+ assert_eq!(Token::For, lex.gettok());
+ assert_eq!(Token::In, lex.gettok());
+ }
}
diff --git a/src/llvm/builder.rs b/src/llvm/builder.rs
index 8f581f9..da10231 100644
--- a/src/llvm/builder.rs
+++ b/src/llvm/builder.rs
@@ -10,7 +10,7 @@ use llvm_sys::{
use std::marker::PhantomData;
-use super::{BasicBlock, FnValue, Module, Type, Value};
+use super::{BasicBlock, FnValue, Module, PhiValue, Type, Value};
// Definition of LLVM C API functions using our `repr(transparent)` types.
extern "C" {
@@ -131,8 +131,8 @@ impl<'llvm> IRBuilder<'llvm> {
///
/// Panics if LLVM API returns a `null` pointer.
pub fn fcmpult(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> {
- debug_assert!(lhs.is_f64(), "fcmplt: Expected f64 as lhs operand!");
- debug_assert!(rhs.is_f64(), "fcmplt: Expected f64 as rhs operand!");
+ debug_assert!(lhs.is_f64(), "fcmpult: Expected f64 as lhs operand!");
+ debug_assert!(rhs.is_f64(), "fcmpult: Expected f64 as rhs operand!");
let value_ref = unsafe {
LLVMBuildFCmp(
@@ -140,7 +140,7 @@ impl<'llvm> IRBuilder<'llvm> {
LLVMRealPredicate::LLVMRealULT,
lhs.value_ref(),
rhs.value_ref(),
- b"fcmplt\0".as_ptr().cast(),
+ b"fcmpult\0".as_ptr().cast(),
)
};
Value::new(value_ref)
@@ -152,8 +152,8 @@ impl<'llvm> IRBuilder<'llvm> {
///
/// Panics if LLVM API returns a `null` pointer.
pub fn fcmpone(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> {
- debug_assert!(lhs.is_f64(), "fcmplt: Expected f64 as lhs operand!");
- debug_assert!(rhs.is_f64(), "fcmplt: Expected f64 as rhs operand!");
+ debug_assert!(lhs.is_f64(), "fcmone: Expected f64 as lhs operand!");
+ debug_assert!(rhs.is_f64(), "fcmone: Expected f64 as rhs operand!");
let value_ref = unsafe {
LLVMBuildFCmp(
@@ -251,7 +251,7 @@ impl<'llvm> IRBuilder<'llvm> {
&self,
phi_type: Type<'llvm>,
incoming: &[(Value<'llvm>, BasicBlock<'llvm>)],
- ) -> Value<'llvm> {
+ ) -> PhiValue<'llvm> {
let phi_ref =
unsafe { LLVMBuildPhi(self.builder, phi_type.type_ref(), b"phi\0".as_ptr().cast()) };
assert!(!phi_ref.is_null());
@@ -268,7 +268,7 @@ impl<'llvm> IRBuilder<'llvm> {
}
}
- Value::new(phi_ref)
+ PhiValue::new(phi_ref)
}
}
diff --git a/src/llvm/mod.rs b/src/llvm/mod.rs
index c9f17b6..1eb9c57 100644
--- a/src/llvm/mod.rs
+++ b/src/llvm/mod.rs
@@ -33,7 +33,7 @@ pub use lljit::{LLJit, ResourceTracker};
pub use module::Module;
pub use pass_manager::FunctionPassManager;
pub use type_::Type;
-pub use value::{FnValue, Value};
+pub use value::{FnValue, PhiValue, Value};
struct Error<'llvm>(&'llvm mut libc::c_char);
diff --git a/src/llvm/value.rs b/src/llvm/value.rs
index 219c855..912c14e 100644
--- a/src/llvm/value.rs
+++ b/src/llvm/value.rs
@@ -1,9 +1,11 @@
+#![allow(unused)]
+
use llvm_sys::{
analysis::{LLVMVerifierFailureAction, LLVMVerifyFunction},
core::{
- LLVMAppendExistingBasicBlock, LLVMCountBasicBlocks, LLVMCountParams, LLVMDumpValue,
- LLVMGetParam, LLVMGetReturnType, LLVMGetValueKind, LLVMGetValueName2, LLVMSetValueName2,
- LLVMTypeOf,
+ LLVMAddIncoming, LLVMAppendExistingBasicBlock, LLVMCountBasicBlocks, LLVMCountParams,
+ LLVMDumpValue, LLVMGetParam, LLVMGetReturnType, LLVMGetValueKind, LLVMGetValueName2,
+ LLVMIsAFunction, LLVMIsAPHINode, LLVMSetValueName2, LLVMTypeOf,
},
prelude::LLVMValueRef,
LLVMTypeKind, LLVMValueKind,
@@ -43,6 +45,18 @@ impl<'llvm> Value<'llvm> {
unsafe { LLVMGetValueKind(self.value_ref()) }
}
+ /// Check if value is `function` type.
+ pub(super) fn is_function(&self) -> bool {
+ let cast = unsafe { LLVMIsAFunction(self.value_ref()) };
+ !cast.is_null()
+ }
+
+ /// Check if value is `phinode` type.
+ pub(super) fn is_phinode(&self) -> bool {
+ let cast = unsafe { LLVMIsAPHINode(self.value_ref()) };
+ !cast.is_null()
+ }
+
/// Dump the LLVM Value to stdout.
pub fn dump(&self) {
unsafe { LLVMDumpValue(self.value_ref()) };
@@ -117,9 +131,8 @@ impl<'llvm> FnValue<'llvm> {
/// Panics if `value_ref` is a null pointer.
pub(super) fn new(value_ref: LLVMValueRef) -> Self {
let value = Value::new(value_ref);
- debug_assert_eq!(
- value.kind(),
- LLVMValueKind::LLVMFunctionValueKind,
+ debug_assert!(
+ value.is_function(),
"Expected a fn value when constructing FnValue!"
);
@@ -175,3 +188,50 @@ impl<'llvm> FnValue<'llvm> {
}
}
}
+
+/// Wrapper for a LLVM Value Reference specialized for contexts where phi values are needed.
+#[derive(Copy, Clone)]
+#[repr(transparent)]
+pub struct PhiValue<'llvm>(Value<'llvm>);
+
+impl<'llvm> Deref for PhiValue<'llvm> {
+ type Target = Value<'llvm>;
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl<'llvm> PhiValue<'llvm> {
+ /// Create a new PhiValue instance.
+ ///
+ /// # Panics
+ ///
+ /// Panics if `value_ref` is a null pointer.
+ pub(super) fn new(value_ref: LLVMValueRef) -> Self {
+ let value = Value::new(value_ref);
+ debug_assert!(
+ value.is_phinode(),
+ "Expected a phinode value when constructing PhiValue!"
+ );
+
+ PhiValue(value)
+ }
+
+ /// Add an incoming value to the end of a PHI list.
+ pub fn add_incoming(&self, ival: Value<'llvm>, ibb: BasicBlock<'llvm>) {
+ debug_assert_eq!(
+ ival.type_of().kind(),
+ self.type_of().kind(),
+ "Type of incoming phi value must be the same as the type used to build the phi node."
+ );
+
+ unsafe {
+ LLVMAddIncoming(
+ self.value_ref(),
+ &mut ival.value_ref() as _,
+ &mut ibb.bb_ref() as _,
+ 1,
+ );
+ }
+ }
+}
diff --git a/src/parser.rs b/src/parser.rs
index 39e69ce..3b4fbb2 100644
--- a/src/parser.rs
+++ b/src/parser.rs
@@ -20,6 +20,15 @@ pub enum ExprAST {
then: Box<ExprAST>,
else_: Box<ExprAST>,
},
+
+ /// ForExprAST - Expression class for for/in.
+ For {
+ var: String,
+ start: Box<ExprAST>,
+ end: Box<ExprAST>,
+ step: Option<Box<ExprAST>>,
+ body: Box<ExprAST>,
+ },
}
/// PrototypeAST - This class represents the "prototype" for a function,
@@ -196,6 +205,64 @@ where
})
}
+ /// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
+ ///
+ /// Implement `std::unique_ptr<ExprAST> ParseForExpr();` from the tutorial.
+ fn parse_for_expr(&mut self) -> ParseResult<ExprAST> {
+ // Consume the 'for' token.
+ assert_eq!(*self.cur_tok(), Token::For);
+ self.get_next_token();
+
+ let var = match self
+ .parse_identifier_expr()
+ .map_err(|_| String::from("expected identifier after 'for'"))?
+ {
+ ExprAST::Variable(var) => var,
+ _ => unreachable!(),
+ };
+
+ // Consume the '=' token.
+ if *self.cur_tok() != Token::Char('=') {
+ return Err("expected '=' after for".into());
+ }
+ self.get_next_token();
+
+ let start = self.parse_expression()?;
+
+ // Consume the ',' token.
+ if *self.cur_tok() != Token::Char(',') {
+ return Err("expected ',' after for start value".into());
+ }
+ self.get_next_token();
+
+ let end = self.parse_expression()?;
+
+ let step = if *self.cur_tok() == Token::Char(',') {
+ // Consume the ',' token.
+ self.get_next_token();
+
+ Some(self.parse_expression()?)
+ } else {
+ None
+ };
+
+ // Consume the 'in' token.
+ if *self.cur_tok() != Token::In {
+ return Err("expected 'in' after for".into());
+ }
+ self.get_next_token();
+
+ let body = self.parse_expression()?;
+
+ Ok(ExprAST::For {
+ var,
+ start: Box::new(start),
+ end: Box::new(end),
+ step: step.map(|s| Box::new(s)),
+ body: Box::new(body),
+ })
+ }
+
/// primary
/// ::= identifierexpr
/// ::= numberexpr
@@ -208,6 +275,7 @@ where
Token::Number(_) => self.parse_num_expr(),
Token::Char('(') => self.parse_paren_expr(),
Token::If => self.parse_if_expr(),
+ Token::For => self.parse_for_expr(),
_ => Err("unknown token when expecting an expression".into()),
}
}
@@ -420,8 +488,52 @@ mod test {
}
#[test]
+ fn parse_for() {
+ let mut p = parser("for i = 1, 2, 3 in 4");
+
+ let var = String::from("i");
+ let start = Box::new(ExprAST::Number(1f64));
+ let end = Box::new(ExprAST::Number(2f64));
+ let step = Some(Box::new(ExprAST::Number(3f64)));
+ let body = Box::new(ExprAST::Number(4f64));
+
+ assert_eq!(
+ p.parse_for_expr(),
+ Ok(ExprAST::For {
+ var,
+ start,
+ end,
+ step,
+ body
+ })
+ );
+ }
+
+ #[test]
+ fn parse_for_no_step() {
+ let mut p = parser("for i = 1, 2 in 4");
+
+ let var = String::from("i");
+ let start = Box::new(ExprAST::Number(1f64));
+ let end = Box::new(ExprAST::Number(2f64));
+ let step = None;
+ let body = Box::new(ExprAST::Number(4f64));
+
+ assert_eq!(
+ p.parse_for_expr(),
+ Ok(ExprAST::For {
+ var,
+ start,
+ end,
+ step,
+ body
+ })
+ );
+ }
+
+ #[test]
fn parse_primary() {
- let mut p = parser("1337 foop \n bla(123) \n if a then b else c");
+ let mut p = parser("1337 foop \n bla(123) \n if a then b else c \n for x=1,2 in 3");
assert_eq!(p.parse_primary(), Ok(ExprAST::Number(1337f64)));
@@ -440,6 +552,17 @@ mod test {
else_: Box::new(ExprAST::Variable("c".into())),
})
);
+
+ assert_eq!(
+ p.parse_primary(),
+ Ok(ExprAST::For {
+ var: String::from("x"),
+ start: Box::new(ExprAST::Number(1f64)),
+ end: Box::new(ExprAST::Number(2f64)),
+ step: None,
+ body: Box::new(ExprAST::Number(3f64)),
+ })
+ );
}
#[test]