aboutsummaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
authorJohannes Stoelp <johannes.stoelp@gmail.com>2021-09-14 00:19:40 +0200
committerJohannes Stoelp <johannes.stoelp@gmail.com>2021-09-14 00:19:40 +0200
commit9e6c0a92dbedb5b8801772802e2e5d2e56cb9bcf (patch)
tree96e798b847138a2edffadad02ef6ec74cfd8c739 /src
parent96e9dd5f4ae46b5705b8063a43bb8576e1e5b7b0 (diff)
downloadllvm-kaleidoscope-rs-chapter3.tar.gz
llvm-kaleidoscope-rs-chapter3.zip
ch3: added LLVM IR code genchapter3
- Added safe wrapper around LLVM C API - Added codegen module to emit LLVM IR for the AST - Update the main repl loop to codegen LLVM IR
Diffstat (limited to 'src')
-rw-r--r--src/codegen.rs135
-rw-r--r--src/lib.rs109
-rw-r--r--src/llvm.rs534
-rw-r--r--src/main.rs53
-rw-r--r--src/parser.rs5
5 files changed, 821 insertions, 15 deletions
diff --git a/src/codegen.rs b/src/codegen.rs
new file mode 100644
index 0000000..866aab7
--- /dev/null
+++ b/src/codegen.rs
@@ -0,0 +1,135 @@
+use std::collections::HashMap;
+
+use crate::llvm::{Builder, FnValue, Module, Value};
+use crate::parser::{ExprAST, FunctionAST, PrototypeAST};
+use crate::Either;
+
+type CodegenResult<T> = Result<T, String>;
+
+/// Code generator from kaleidoscope AST to LLVM IR.
+pub struct Codegen<'llvm, 'a> {
+ module: &'llvm Module,
+ builder: &'a Builder<'llvm>,
+}
+
+impl<'llvm, 'a> Codegen<'llvm, 'a> {
+ /// Compile either a [`PrototypeAST`] or a [`FunctionAST`] into the LLVM `module`.
+ pub fn compile(
+ module: &'llvm Module,
+ compilee: Either<&PrototypeAST, &FunctionAST>,
+ ) -> CodegenResult<FnValue<'llvm>> {
+ let cg = Codegen {
+ module,
+ builder: &Builder::with_ctx(module),
+ };
+ let mut variables = HashMap::new();
+
+ match compilee {
+ Either::A(proto) => Ok(cg.codegen_prototype(proto)),
+ Either::B(func) => cg.codegen_function(func, &mut variables),
+ }
+ }
+
+ fn codegen_expr(
+ &self,
+ expr: &ExprAST,
+ named_values: &mut HashMap<&'llvm str, Value<'llvm>>,
+ ) -> CodegenResult<Value<'llvm>> {
+ match expr {
+ ExprAST::Number(num) => Ok(self.module.type_f64().const_f64(*num)),
+ ExprAST::Variable(name) => match named_values.get(name.as_str()) {
+ Some(value) => Ok(*value),
+ None => Err("Unknown variable name".into()),
+ },
+ ExprAST::Binary(binop, lhs, rhs) => {
+ let l = self.codegen_expr(lhs, named_values)?;
+ let r = self.codegen_expr(rhs, named_values)?;
+
+ match binop {
+ '+' => Ok(self.builder.fadd(l, r)),
+ '-' => Ok(self.builder.fsub(l, r)),
+ '*' => Ok(self.builder.fmul(l, r)),
+ '<' => {
+ let res = self.builder.fcmpult(l, r);
+ // Turn bool into f64.
+ Ok(self.builder.uitofp(res, self.module.type_f64()))
+ }
+ _ => Err("invalid binary operator".into()),
+ }
+ }
+ ExprAST::Call(callee, args) => match self.module.get_fn(callee) {
+ Some(callee) => {
+ if callee.args() != args.len() {
+ return Err("Incorrect # arguments passed".into());
+ }
+
+ // Generate code for function argument expressions.
+ let mut args: Vec<Value<'_>> = args
+ .iter()
+ .map(|arg| self.codegen_expr(arg, named_values))
+ .collect::<CodegenResult<_>>()?;
+
+ Ok(self.builder.call(callee, &mut args))
+ }
+ None => Err("Unknown function referenced".into()),
+ },
+ }
+ }
+
+ fn codegen_prototype(&self, PrototypeAST(name, args): &PrototypeAST) -> FnValue<'llvm> {
+ let type_f64 = self.module.type_f64();
+
+ let mut doubles = Vec::new();
+ doubles.resize(args.len(), type_f64);
+
+ // Build the function type: fn(f64, f64, ..) -> f64
+ let ft = self.module.type_fn(&mut doubles, type_f64);
+
+ // Create the function declaration.
+ let f = self.module.add_fn(name, ft);
+
+ // Set the names of the function arguments.
+ for idx in 0..f.args() {
+ f.arg(idx).set_name(&args[idx]);
+ }
+
+ f
+ }
+
+ fn codegen_function(
+ &self,
+ FunctionAST(proto, body): &FunctionAST,
+ named_values: &mut HashMap<&'llvm str, Value<'llvm>>,
+ ) -> CodegenResult<FnValue<'llvm>> {
+ let the_function = match self.module.get_fn(&proto.0) {
+ Some(f) => f,
+ None => self.codegen_prototype(proto),
+ };
+
+ if the_function.basic_blocks() > 0 {
+ return Err("Function cannot be redefined.".into());
+ }
+
+ // Create entry basic block to insert code.
+ let bb = self.module.append_basic_block(the_function);
+ self.builder.pos_at_end(bb);
+
+ // New scope, clear the map with the function args.
+ named_values.clear();
+
+ // Update the map with the current functions args.
+ for idx in 0..the_function.args() {
+ let arg = the_function.arg(idx);
+ named_values.insert(arg.get_name(), arg);
+ }
+
+ // Codegen function body.
+ if let Ok(ret) = self.codegen_expr(body, named_values) {
+ self.builder.ret(ret);
+ assert!(the_function.verify());
+ Ok(the_function)
+ } else {
+ todo!("Failed to codegen function body, erase from module!");
+ }
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..64721a7
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,109 @@
+use std::convert::TryFrom;
+
+pub mod codegen;
+pub mod lexer;
+pub mod llvm;
+pub mod parser;
+
+/// Fixed size of [`SmallCStr`] including the trailing `\0` byte.
+pub const SMALL_STR_SIZE: usize = 16;
+
+/// Small C string on the stack with fixed size [`SMALL_STR_SIZE`].
+#[derive(Debug, PartialEq)]
+pub struct SmallCStr([u8; SMALL_STR_SIZE]);
+
+impl SmallCStr {
+ /// Create a new C string from `src`.
+ /// Returns [`None`] if `src` exceeds the fixed size or contains any `\0` bytes.
+ pub fn new<T: AsRef<[u8]>>(src: &T) -> Option<SmallCStr> {
+ let src = src.as_ref();
+ let len = src.len();
+
+ // Check for \0 bytes.
+ let contains_null = unsafe { !libc::memchr(src.as_ptr().cast(), 0, len).is_null() };
+
+ if contains_null || len > SMALL_STR_SIZE - 1 {
+ None
+ } else {
+ let mut dest = [0; SMALL_STR_SIZE];
+ dest[..len].copy_from_slice(src);
+ Some(SmallCStr(dest))
+ }
+ }
+
+ /// Return pointer to C string.
+ pub const fn as_ptr(&self) -> *const libc::c_char {
+ self.0.as_ptr().cast()
+ }
+}
+
+impl TryFrom<&str> for SmallCStr {
+ type Error = ();
+
+ fn try_from(value: &str) -> Result<Self, Self::Error> {
+ SmallCStr::new(&value).ok_or(())
+ }
+}
+
+/// Either type, for APIs accepting two types.
+pub enum Either<A, B> {
+ A(A),
+ B(B),
+}
+
+#[cfg(test)]
+mod test {
+ use super::{SmallCStr, SMALL_STR_SIZE};
+ use std::convert::TryInto;
+
+ #[test]
+ fn test_create() {
+ let src = "\x30\x31\x32\x33";
+ let scs = SmallCStr::new(&src).unwrap();
+ assert_eq!(&scs.0[..5], &[0x30, 0x31, 0x32, 0x33, 0x00]);
+
+ let src = b"abcd1234";
+ let scs = SmallCStr::new(&src).unwrap();
+ assert_eq!(
+ &scs.0[..9],
+ &[0x61, 0x62, 0x63, 0x64, 0x31, 0x32, 0x33, 0x34, 0x00]
+ );
+ }
+
+ #[test]
+ fn test_contain_null() {
+ let src = "\x30\x00\x32\x33";
+ let scs = SmallCStr::new(&src);
+ assert_eq!(scs, None);
+
+ let src = "\x30\x31\x32\x33\x00";
+ let scs = SmallCStr::new(&src);
+ assert_eq!(scs, None);
+ }
+
+ #[test]
+ fn test_too_large() {
+ let src = (0..SMALL_STR_SIZE).map(|_| 'a').collect::<String>();
+ let scs = SmallCStr::new(&src);
+ assert_eq!(scs, None);
+
+ let src = (0..SMALL_STR_SIZE + 10).map(|_| 'a').collect::<String>();
+ let scs = SmallCStr::new(&src);
+ assert_eq!(scs, None);
+ }
+
+ #[test]
+ fn test_try_into() {
+ let src = "\x30\x31\x32\x33";
+ let scs: Result<SmallCStr, ()> = src.try_into();
+ assert!(scs.is_ok());
+
+ let src = (0..SMALL_STR_SIZE).map(|_| 'a').collect::<String>();
+ let scs: Result<SmallCStr, ()> = src.as_str().try_into();
+ assert!(scs.is_err());
+
+ let src = (0..SMALL_STR_SIZE + 10).map(|_| 'a').collect::<String>();
+ let scs: Result<SmallCStr, ()> = src.as_str().try_into();
+ assert!(scs.is_err());
+ }
+}
diff --git a/src/llvm.rs b/src/llvm.rs
new file mode 100644
index 0000000..ed6b930
--- /dev/null
+++ b/src/llvm.rs
@@ -0,0 +1,534 @@
+//! Safe wrapper around the LLVM C API.
+//!
+//! References returned from the LLVM API are tied to the `'llvm` lifetime which is bound to the
+//! context where the objects are created in.
+//! We do not offer wrappers to remove or delete any objects in the context and therefore all the
+//! references will be valid for the liftime of the context.
+
+use llvm_sys::analysis::{LLVMVerifierFailureAction, LLVMVerifyFunction};
+use llvm_sys::core::{
+ LLVMAddFunction, LLVMAppendBasicBlockInContext, LLVMBuildFAdd, LLVMBuildFCmp, LLVMBuildFMul,
+ LLVMBuildFSub, LLVMBuildRet, LLVMBuildUIToFP, LLVMConstReal, LLVMContextCreate,
+ LLVMContextDispose, LLVMCountBasicBlocks, LLVMCountParams, LLVMCreateBuilderInContext,
+ LLVMDisposeBuilder, LLVMDisposeModule, LLVMDoubleTypeInContext, LLVMDumpModule, LLVMDumpType,
+ LLVMDumpValue, LLVMGetNamedFunction, LLVMGetParam, LLVMGetReturnType, LLVMGetTypeKind,
+ LLVMGetValueKind, LLVMGetValueName2, LLVMModuleCreateWithNameInContext,
+ LLVMPositionBuilderAtEnd, LLVMSetValueName2, LLVMTypeOf,
+};
+use llvm_sys::prelude::{
+ LLVMBasicBlockRef, LLVMBool, LLVMBuilderRef, LLVMContextRef, LLVMModuleRef, LLVMTypeRef,
+ LLVMValueRef,
+};
+use llvm_sys::{LLVMRealPredicate, LLVMTypeKind, LLVMValueKind};
+
+use std::convert::TryFrom;
+use std::ffi::CStr;
+use std::marker::PhantomData;
+use std::ops::Deref;
+
+use crate::SmallCStr;
+
+// Definition of LLVM C API functions using our `repr(transparent)` types.
+extern "C" {
+ fn LLVMFunctionType(
+ ReturnType: Type<'_>,
+ ParamTypes: *mut Type<'_>,
+ ParamCount: ::libc::c_uint,
+ IsVarArg: LLVMBool,
+ ) -> LLVMTypeRef;
+ fn LLVMBuildCall2(
+ arg1: LLVMBuilderRef,
+ arg2: Type<'_>,
+ Fn: FnValue<'_>,
+ Args: *mut Value<'_>,
+ NumArgs: ::libc::c_uint,
+ Name: *const ::libc::c_char,
+ ) -> LLVMValueRef;
+}
+
+// ====================
+// Module / Context
+// ====================
+
+/// Wrapper for a LLVM Module with its own LLVM Context.
+pub struct Module {
+ ctx: LLVMContextRef,
+ module: LLVMModuleRef,
+}
+
+impl<'llvm> Module {
+ /// Create a new Module instance.
+ ///
+ /// # Panics
+ ///
+ /// Panics if creating the context or the module fails.
+ pub fn new() -> Self {
+ let (ctx, module) = unsafe {
+ let c = LLVMContextCreate();
+ let m = LLVMModuleCreateWithNameInContext(b"module\0".as_ptr().cast(), c);
+ assert!(!c.is_null() && !m.is_null());
+ (c, m)
+ };
+
+ Module { ctx, module }
+ }
+
+ /// Dump LLVM IR emitted into the Module to stdout.
+ pub fn dump(&self) {
+ unsafe { LLVMDumpModule(self.module) };
+ }
+
+ /// Get a type reference representing a `f64` float.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn type_f64(&self) -> Type<'llvm> {
+ let type_ref = unsafe { LLVMDoubleTypeInContext(self.ctx) };
+ Type::new(type_ref)
+ }
+
+ /// Get a type reference representing a `fn(args) -> ret` function.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn type_fn(&'llvm self, args: &mut [Type<'llvm>], ret: Type<'llvm>) -> Type<'llvm> {
+ let type_ref = unsafe {
+ LLVMFunctionType(
+ ret,
+ args.as_mut_ptr(),
+ args.len() as libc::c_uint,
+ 0, /* IsVarArg */
+ )
+ };
+ Type::new(type_ref)
+ }
+
+ /// Add a function with the given `name` and `fn_type` to the module and return a value
+ /// reference representing the function.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer or `name` could not be converted to a
+ /// [`SmallCStr`].
+ pub fn add_fn(&'llvm self, name: &str, fn_type: Type<'llvm>) -> FnValue<'llvm> {
+ debug_assert_eq!(
+ fn_type.kind(),
+ LLVMTypeKind::LLVMFunctionTypeKind,
+ "Expected a function type when adding a function!"
+ );
+
+ let name = SmallCStr::try_from(name)
+ .expect("Failed to convert 'name' argument to small C string!");
+
+ let value_ref = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type.0) };
+ FnValue::new(value_ref)
+ }
+
+ /// Get a function value reference to the function with the given `name` if it was previously
+ /// added to the module with [`add_fn`][Module::add_fn].
+ ///
+ /// # Panics
+ ///
+ /// Panics if `name` could not be converted to a [`SmallCStr`].
+ pub fn get_fn(&'llvm self, name: &str) -> Option<FnValue<'llvm>> {
+ let name = SmallCStr::try_from(name)
+ .expect("Failed to convert 'name' argument to small C string!");
+
+ let value_ref = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
+
+ (!value_ref.is_null()).then(|| FnValue::new(value_ref))
+ }
+
+ /// Append a Basic Block to the end of the function referenced by the value reference
+ /// `fn_value`.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn append_basic_block(&'llvm self, fn_value: FnValue<'llvm>) -> BasicBlock<'llvm> {
+ let block = unsafe {
+ LLVMAppendBasicBlockInContext(
+ self.ctx,
+ fn_value.value_ref(),
+ b"block\0".as_ptr().cast(),
+ )
+ };
+ assert!(!block.is_null());
+
+ BasicBlock(block, PhantomData)
+ }
+}
+
+impl Drop for Module {
+ fn drop(&mut self) {
+ unsafe {
+ LLVMDisposeModule(self.module);
+ LLVMContextDispose(self.ctx);
+ }
+ }
+}
+
+// ===========
+// Builder
+// ===========
+
+/// Wrapper for a LLVM IR Builder.
+pub struct Builder<'llvm> {
+ builder: LLVMBuilderRef,
+ _ctx: PhantomData<&'llvm ()>,
+}
+
+impl<'llvm> Builder<'llvm> {
+ /// Create a new LLVM IR Builder with the `module`s context.
+ ///
+ /// # Panics
+ ///
+ /// Panics if creating the IR Builder fails.
+ pub fn with_ctx(module: &'llvm Module) -> Builder<'llvm> {
+ let builder = unsafe { LLVMCreateBuilderInContext(module.ctx) };
+ assert!(!builder.is_null());
+
+ Builder {
+ builder,
+ _ctx: PhantomData,
+ }
+ }
+
+ /// Position the IR Builder at the end of the given Basic Block.
+ pub fn pos_at_end(&self, bb: BasicBlock<'llvm>) {
+ unsafe {
+ LLVMPositionBuilderAtEnd(self.builder, bb.0);
+ }
+ }
+
+ /// Emit a [fadd](https://llvm.org/docs/LangRef.html#fadd-instruction) instruction.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn fadd(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> {
+ debug_assert!(lhs.is_f64(), "fadd: Expected f64 as lhs operand!");
+ debug_assert!(rhs.is_f64(), "fadd: Expected f64 as rhs operand!");
+
+ let value_ref = unsafe {
+ LLVMBuildFAdd(
+ self.builder,
+ lhs.value_ref(),
+ rhs.value_ref(),
+ b"fadd\0".as_ptr().cast(),
+ )
+ };
+ Value::new(value_ref)
+ }
+
+ /// Emit a [fsub](https://llvm.org/docs/LangRef.html#fsub-instruction) instruction.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn fsub(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> {
+ debug_assert!(lhs.is_f64(), "fsub: Expected f64 as lhs operand!");
+ debug_assert!(rhs.is_f64(), "fsub: Expected f64 as rhs operand!");
+
+ let value_ref = unsafe {
+ LLVMBuildFSub(
+ self.builder,
+ lhs.value_ref(),
+ rhs.value_ref(),
+ b"fsub\0".as_ptr().cast(),
+ )
+ };
+ Value::new(value_ref)
+ }
+
+ /// Emit a [fmul](https://llvm.org/docs/LangRef.html#fmul-instruction) instruction.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn fmul(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> {
+ debug_assert!(lhs.is_f64(), "fmul: Expected f64 as lhs operand!");
+ debug_assert!(rhs.is_f64(), "fmul: Expected f64 as rhs operand!");
+
+ let value_ref = unsafe {
+ LLVMBuildFMul(
+ self.builder,
+ lhs.value_ref(),
+ rhs.value_ref(),
+ b"fmul\0".as_ptr().cast(),
+ )
+ };
+ Value::new(value_ref)
+ }
+
+ /// Emit a [fcmult](https://llvm.org/docs/LangRef.html#fcmp-instruction) instruction.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn fcmpult(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> {
+ debug_assert!(lhs.is_f64(), "fcmplt: Expected f64 as lhs operand!");
+ debug_assert!(rhs.is_f64(), "fcmplt: Expected f64 as rhs operand!");
+
+ let value_ref = unsafe {
+ LLVMBuildFCmp(
+ self.builder,
+ LLVMRealPredicate::LLVMRealULT,
+ lhs.value_ref(),
+ rhs.value_ref(),
+ b"fcmplt\0".as_ptr().cast(),
+ )
+ };
+ Value::new(value_ref)
+ }
+
+ /// Emit a [uitofp](https://llvm.org/docs/LangRef.html#uitofp-to-instruction) instruction.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn uitofp(&self, val: Value<'llvm>, dest_type: Type<'llvm>) -> Value<'llvm> {
+ debug_assert!(val.is_int(), "uitofp: Expected integer operand!");
+
+ let value_ref = unsafe {
+ LLVMBuildUIToFP(
+ self.builder,
+ val.value_ref(),
+ dest_type.0,
+ b"uitofp\0".as_ptr().cast(),
+ )
+ };
+ Value::new(value_ref)
+ }
+
+ /// Emit a [call](https://llvm.org/docs/LangRef.html#call-instruction) instruction.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn call(&self, fn_value: FnValue<'llvm>, args: &mut [Value<'llvm>]) -> Value<'llvm> {
+ let value_ref = unsafe {
+ LLVMBuildCall2(
+ self.builder,
+ fn_value.ret_type(),
+ fn_value,
+ args.as_mut_ptr(),
+ args.len() as libc::c_uint,
+ b"call\0".as_ptr().cast(),
+ )
+ };
+ Value::new(value_ref)
+ }
+
+ /// Emit a [ret](https://llvm.org/docs/LangRef.html#ret-instruction) instruction.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn ret(&self, ret: Value<'llvm>) {
+ let ret = unsafe { LLVMBuildRet(self.builder, ret.value_ref()) };
+ assert!(!ret.is_null());
+ }
+}
+
+impl Drop for Builder<'_> {
+ fn drop(&mut self) {
+ unsafe { LLVMDisposeBuilder(self.builder) }
+ }
+}
+
+// ==============
+// BasicBlock
+// ==============
+
+/// Wrapper for a LLVM Basic Block.
+#[derive(Copy, Clone)]
+pub struct BasicBlock<'llvm>(LLVMBasicBlockRef, PhantomData<&'llvm ()>);
+
+// ========
+// Type
+// ========
+
+/// Wrapper for a LLVM Type Reference.
+#[derive(Copy, Clone)]
+#[repr(transparent)]
+pub struct Type<'llvm>(LLVMTypeRef, PhantomData<&'llvm ()>);
+
+impl<'llvm> Type<'llvm> {
+ fn new(type_ref: LLVMTypeRef) -> Self {
+ assert!(!type_ref.is_null());
+ Type(type_ref, PhantomData)
+ }
+
+ fn kind(&self) -> LLVMTypeKind {
+ unsafe { LLVMGetTypeKind(self.0) }
+ }
+
+ /// Dump the LLVM Type to stdout.
+ pub fn dump(&self) {
+ unsafe { LLVMDumpType(self.0) };
+ }
+
+ /// Get a value reference representing the const `f64` value.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn const_f64(self, n: f64) -> Value<'llvm> {
+ debug_assert_eq!(
+ self.kind(),
+ LLVMTypeKind::LLVMDoubleTypeKind,
+ "Expected a double type when creating const f64 value!"
+ );
+
+ let value_ref = unsafe { LLVMConstReal(self.0, n) };
+ Value::new(value_ref)
+ }
+}
+
+// =========
+// Value
+// =========
+
+/// Wrapper for a LLVM Value Reference.
+#[derive(Copy, Clone)]
+#[repr(transparent)]
+pub struct Value<'llvm>(LLVMValueRef, PhantomData<&'llvm ()>);
+
+impl<'llvm> Value<'llvm> {
+ fn new(value_ref: LLVMValueRef) -> Self {
+ assert!(!value_ref.is_null());
+ Value(value_ref, PhantomData)
+ }
+
+ #[inline]
+ fn value_ref(&self) -> LLVMValueRef {
+ self.0
+ }
+
+ fn kind(&self) -> LLVMValueKind {
+ unsafe { LLVMGetValueKind(self.value_ref()) }
+ }
+
+ /// Dump the LLVM Value to stdout.
+ pub fn dump(&self) {
+ unsafe { LLVMDumpValue(self.value_ref()) };
+ }
+
+ /// Get a type reference representing for the given value reference.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn type_of(&self) -> Type<'llvm> {
+ let type_ref = unsafe { LLVMTypeOf(self.value_ref()) };
+ Type::new(type_ref)
+ }
+
+ /// Set the name for the given value reference.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn set_name(&self, name: &str) {
+ unsafe { LLVMSetValueName2(self.value_ref(), name.as_ptr().cast(), name.len()) };
+ }
+
+ /// Get the name for the given value reference.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn get_name(&self) -> &'llvm str {
+ let name = unsafe {
+ let mut len: libc::size_t = 0;
+ let name = LLVMGetValueName2(self.0, &mut len as _);
+ assert!(!name.is_null());
+
+ CStr::from_ptr(name)
+ };
+
+ // TODO: Does this string live for the time of the LLVM context?!
+ name.to_str()
+ .expect("Expected valid UTF8 string from LLVM API")
+ }
+
+ /// Check if value is of `f64` type.
+ pub fn is_f64(&self) -> bool {
+ self.type_of().kind() == LLVMTypeKind::LLVMDoubleTypeKind
+ }
+
+ /// Check if value is of integer type.
+ pub fn is_int(&self) -> bool {
+ self.type_of().kind() == LLVMTypeKind::LLVMIntegerTypeKind
+ }
+}
+
+/// Wrapper for a LLVM Value Reference specialized for contexts where function values are needed.
+#[derive(Copy, Clone)]
+#[repr(transparent)]
+pub struct FnValue<'llvm>(Value<'llvm>);
+
+impl<'llvm> Deref for FnValue<'llvm> {
+ type Target = Value<'llvm>;
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl<'llvm> FnValue<'llvm> {
+ fn new(value_ref: LLVMValueRef) -> Self {
+ let value = Value::new(value_ref);
+ debug_assert_eq!(
+ value.kind(),
+ LLVMValueKind::LLVMFunctionValueKind,
+ "Expected a fn value when constructing FnValue!"
+ );
+
+ FnValue(value)
+ }
+
+ /// Get a type reference representing the return value of the given function value.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer.
+ pub fn ret_type(&self) -> Type<'llvm> {
+ let type_ref = unsafe { LLVMGetReturnType(LLVMTypeOf(self.value_ref())) };
+ Type::new(type_ref)
+ }
+
+ /// Get the number of function arguments for the given function value.
+ pub fn args(&self) -> usize {
+ unsafe { LLVMCountParams(self.value_ref()) as usize }
+ }
+
+ /// Get a value reference for the function argument at index `idx`.
+ ///
+ /// # Panics
+ ///
+ /// Panics if LLVM API returns a `null` pointer or indexed out of bounds.
+ pub fn arg(&self, idx: usize) -> Value<'llvm> {
+ assert!(idx < self.args());
+
+ let value_ref = unsafe { LLVMGetParam(self.value_ref(), idx as libc::c_uint) };
+ Value::new(value_ref)
+ }
+
+ /// Get the number of Basic Blocks for the given function value.
+ pub fn basic_blocks(&self) -> usize {
+ unsafe { LLVMCountBasicBlocks(self.value_ref()) as usize }
+ }
+
+ /// Verify that the given function is valid.
+ pub fn verify(&self) -> bool {
+ unsafe {
+ LLVMVerifyFunction(
+ self.value_ref(),
+ LLVMVerifierFailureAction::LLVMPrintMessageAction,
+ ) == 0
+ }
+ }
+}
diff --git a/src/main.rs b/src/main.rs
index 2646873..7160e04 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,16 +1,24 @@
-mod lexer;
-mod parser;
+use llvm_kaleidoscope_rs::{
+ codegen::Codegen,
+ lexer::{Lexer, Token},
+ llvm,
+ parser::Parser,
+ Either,
+};
-use lexer::{Lexer, Token};
-use parser::Parser;
use std::io::Read;
-fn handle_definition<I>(p: &mut Parser<I>)
+fn handle_definition<I>(p: &mut Parser<I>, module: &llvm::Module)
where
I: Iterator<Item = char>,
{
match p.parse_definition() {
- Ok(expr) => println!("Parse 'def'\n{:?}", expr),
+ Ok(func) => {
+ println!("Parse 'def'\n{:?}", func);
+ if let Ok(func) = Codegen::compile(module, Either::B(&func)) {
+ func.dump();
+ }
+ }
Err(err) => {
eprintln!("Error: {:?}", err);
p.get_next_token();
@@ -18,12 +26,17 @@ where
}
}
-fn handle_extern<I>(p: &mut Parser<I>)
+fn handle_extern<I>(p: &mut Parser<I>, module: &llvm::Module)
where
I: Iterator<Item = char>,
{
match p.parse_extern() {
- Ok(expr) => println!("Parse 'extern'\n{:?}", expr),
+ Ok(proto) => {
+ println!("Parse 'extern'\n{:?}", proto);
+ if let Ok(proto) = Codegen::compile(module, Either::A(&proto)) {
+ proto.dump();
+ }
+ }
Err(err) => {
eprintln!("Error: {:?}", err);
p.get_next_token();
@@ -31,12 +44,17 @@ where
}
}
-fn handle_top_level_expression<I>(p: &mut Parser<I>)
+fn handle_top_level_expression<I>(p: &mut Parser<I>, module: &llvm::Module)
where
I: Iterator<Item = char>,
{
match p.parse_top_level_expr() {
- Ok(expr) => println!("Parse top-level expression\n{:?}", expr),
+ Ok(func) => {
+ println!("Parse top-level expression\n{:?}", func);
+ if let Ok(func) = Codegen::compile(module, Either::B(&func)) {
+ func.dump();
+ }
+ }
Err(err) => {
eprintln!("Error: {:?}", err);
p.get_next_token();
@@ -49,16 +67,22 @@ fn main() {
println!("ENTER to parse current input.");
println!("C-d to exit.");
+ // Create lexer over stdin.
let lexer = Lexer::new(std::io::stdin().bytes().filter_map(|v| {
let v = v.ok()?;
Some(v.into())
}));
+ // Create parser for kaleidoscope.
let mut parser = Parser::new(lexer);
// Throw first coin and initialize cur_tok.
parser.get_next_token();
+ // Initialize LLVM module with its own context.
+ // We will emit LLVM IR into this module.
+ let module = llvm::Module::new();
+
loop {
match *parser.cur_tok() {
Token::Eof => break,
@@ -66,9 +90,12 @@ fn main() {
// Ignore top-level semicolon.
parser.get_next_token()
}
- Token::Def => handle_definition(&mut parser),
- Token::Extern => handle_extern(&mut parser),
- _ => handle_top_level_expression(&mut parser),
+ Token::Def => handle_definition(&mut parser, &module),
+ Token::Extern => handle_extern(&mut parser, &module),
+ _ => handle_top_level_expression(&mut parser, &module),
}
}
+
+ // Dump all the emitted LLVM IR to stdout.
+ module.dump();
}
diff --git a/src/parser.rs b/src/parser.rs
index 63f5a77..af69a87 100644
--- a/src/parser.rs
+++ b/src/parser.rs
@@ -19,15 +19,16 @@ pub enum ExprAST {
/// which captures its name, and its argument names (thus implicitly the number
/// of arguments the function takes).
#[derive(Debug, PartialEq)]
-pub struct PrototypeAST(String, Vec<String>);
+pub struct PrototypeAST(pub String, pub Vec<String>);
/// FunctionAST - This class represents a function definition itself.
#[derive(Debug, PartialEq)]
-pub struct FunctionAST(PrototypeAST, ExprAST);
+pub struct FunctionAST(pub PrototypeAST, pub ExprAST);
/// Parse result with String as Error type (to be compliant with tutorial).
type ParseResult<T> = Result<T, String>;
+/// Parser for the `kaleidoscope` language.
pub struct Parser<I>
where
I: Iterator<Item = char>,