From 9e6c0a92dbedb5b8801772802e2e5d2e56cb9bcf Mon Sep 17 00:00:00 2001 From: Johannes Stoelp Date: Tue, 14 Sep 2021 00:19:40 +0200 Subject: ch3: added LLVM IR code gen - Added safe wrapper around LLVM C API - Added codegen module to emit LLVM IR for the AST - Update the main repl loop to codegen LLVM IR --- Cargo.lock | 100 +++++++++++ Cargo.toml | 2 + src/codegen.rs | 135 +++++++++++++++ src/lib.rs | 109 ++++++++++++ src/llvm.rs | 534 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 53 ++++-- src/parser.rs | 5 +- 7 files changed, 923 insertions(+), 15 deletions(-) create mode 100644 src/codegen.rs create mode 100644 src/lib.rs create mode 100644 src/llvm.rs diff --git a/Cargo.lock b/Cargo.lock index dd2a11a..ece31e8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,106 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "aho-corasick" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f" +dependencies = [ + "memchr", +] + +[[package]] +name = "cc" +version = "1.0.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d26a6ce4b6a484fa3edb70f7efa6fc430fd2b87285fe8b84304fd0936faa0dc0" + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + +[[package]] +name = "libc" +version = "0.2.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cb00336871be5ed2c8ed44b60ae9959dc5b9f08539422ed43f09e34ecaeba21" + [[package]] name = "llvm-kaleidoscope-rs" version = "0.1.0" +dependencies = [ + "libc", + "llvm-sys", +] + +[[package]] +name = "llvm-sys" +version = "120.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4a810627ac62b396f5fd2214ba9bbd8748d4d6efdc4d2c1c1303ea7a75763ce" 
+dependencies = [ + "cc", + "lazy_static", + "libc", + "regex", + "semver", +] + +[[package]] +name = "memchr" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a" + +[[package]] +name = "pest" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10f4872ae94d7b90ae48754df22fd42ad52ce740b8f370b03da4835417403e53" +dependencies = [ + "ucd-trie", +] + +[[package]] +name = "regex" +version = "1.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.6.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" + +[[package]] +name = "semver" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f301af10236f6df4160f7c3f04eec6dbc70ace82d23326abad5edee88801c6b6" +dependencies = [ + "semver-parser", +] + +[[package]] +name = "semver-parser" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0bef5b7f9e0df16536d3961cfb6e84331c065b4066afb39768d0e319411f7" +dependencies = [ + "pest", +] + +[[package]] +name = "ucd-trie" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56dee185309b50d1f11bfedef0fe6d036842e3fb77413abef29f8f8d1c5d4c1c" diff --git a/Cargo.toml b/Cargo.toml index e345bab..9bcbe8b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,3 +4,5 @@ version = "0.1.0" edition = "2018" [dependencies] +libc = "0.2" +llvm-sys = {version = "120.1", features = ["strict-versioning"]} diff --git a/src/codegen.rs b/src/codegen.rs new file mode 100644 index 0000000..866aab7 --- /dev/null +++ 
b/src/codegen.rs @@ -0,0 +1,135 @@ +use std::collections::HashMap; + +use crate::llvm::{Builder, FnValue, Module, Value}; +use crate::parser::{ExprAST, FunctionAST, PrototypeAST}; +use crate::Either; + +type CodegenResult = Result; + +/// Code generator from kaleidoscope AST to LLVM IR. +pub struct Codegen<'llvm, 'a> { + module: &'llvm Module, + builder: &'a Builder<'llvm>, +} + +impl<'llvm, 'a> Codegen<'llvm, 'a> { + /// Compile either a [`PrototypeAST`] or a [`FunctionAST`] into the LLVM `module`. + pub fn compile( + module: &'llvm Module, + compilee: Either<&PrototypeAST, &FunctionAST>, + ) -> CodegenResult> { + let cg = Codegen { + module, + builder: &Builder::with_ctx(module), + }; + let mut variables = HashMap::new(); + + match compilee { + Either::A(proto) => Ok(cg.codegen_prototype(proto)), + Either::B(func) => cg.codegen_function(func, &mut variables), + } + } + + fn codegen_expr( + &self, + expr: &ExprAST, + named_values: &mut HashMap<&'llvm str, Value<'llvm>>, + ) -> CodegenResult> { + match expr { + ExprAST::Number(num) => Ok(self.module.type_f64().const_f64(*num)), + ExprAST::Variable(name) => match named_values.get(name.as_str()) { + Some(value) => Ok(*value), + None => Err("Unknown variable name".into()), + }, + ExprAST::Binary(binop, lhs, rhs) => { + let l = self.codegen_expr(lhs, named_values)?; + let r = self.codegen_expr(rhs, named_values)?; + + match binop { + '+' => Ok(self.builder.fadd(l, r)), + '-' => Ok(self.builder.fsub(l, r)), + '*' => Ok(self.builder.fmul(l, r)), + '<' => { + let res = self.builder.fcmpult(l, r); + // Turn bool into f64. + Ok(self.builder.uitofp(res, self.module.type_f64())) + } + _ => Err("invalid binary operator".into()), + } + } + ExprAST::Call(callee, args) => match self.module.get_fn(callee) { + Some(callee) => { + if callee.args() != args.len() { + return Err("Incorrect # arguments passed".into()); + } + + // Generate code for function argument expressions. 
+ let mut args: Vec> = args + .iter() + .map(|arg| self.codegen_expr(arg, named_values)) + .collect::>()?; + + Ok(self.builder.call(callee, &mut args)) + } + None => Err("Unknown function referenced".into()), + }, + } + } + + fn codegen_prototype(&self, PrototypeAST(name, args): &PrototypeAST) -> FnValue<'llvm> { + let type_f64 = self.module.type_f64(); + + let mut doubles = Vec::new(); + doubles.resize(args.len(), type_f64); + + // Build the function type: fn(f64, f64, ..) -> f64 + let ft = self.module.type_fn(&mut doubles, type_f64); + + // Create the function declaration. + let f = self.module.add_fn(name, ft); + + // Set the names of the function arguments. + for idx in 0..f.args() { + f.arg(idx).set_name(&args[idx]); + } + + f + } + + fn codegen_function( + &self, + FunctionAST(proto, body): &FunctionAST, + named_values: &mut HashMap<&'llvm str, Value<'llvm>>, + ) -> CodegenResult> { + let the_function = match self.module.get_fn(&proto.0) { + Some(f) => f, + None => self.codegen_prototype(proto), + }; + + if the_function.basic_blocks() > 0 { + return Err("Function cannot be redefined.".into()); + } + + // Create entry basic block to insert code. + let bb = self.module.append_basic_block(the_function); + self.builder.pos_at_end(bb); + + // New scope, clear the map with the function args. + named_values.clear(); + + // Update the map with the current functions args. + for idx in 0..the_function.args() { + let arg = the_function.arg(idx); + named_values.insert(arg.get_name(), arg); + } + + // Codegen function body. 
+ if let Ok(ret) = self.codegen_expr(body, named_values) { + self.builder.ret(ret); + assert!(the_function.verify()); + Ok(the_function) + } else { + todo!("Failed to codegen function body, erase from module!"); + } + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..64721a7 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,109 @@ +use std::convert::TryFrom; + +pub mod codegen; +pub mod lexer; +pub mod llvm; +pub mod parser; + +/// Fixed size of [`SmallCStr`] including the trailing `\0` byte. +pub const SMALL_STR_SIZE: usize = 16; + +/// Small C string on the stack with fixed size [`SMALL_STR_SIZE`]. +#[derive(Debug, PartialEq)] +pub struct SmallCStr([u8; SMALL_STR_SIZE]); + +impl SmallCStr { + /// Create a new C string from `src`. + /// Returns [`None`] if `src` exceeds the fixed size or contains any `\0` bytes. + pub fn new>(src: &T) -> Option { + let src = src.as_ref(); + let len = src.len(); + + // Check for \0 bytes. + let contains_null = unsafe { !libc::memchr(src.as_ptr().cast(), 0, len).is_null() }; + + if contains_null || len > SMALL_STR_SIZE - 1 { + None + } else { + let mut dest = [0; SMALL_STR_SIZE]; + dest[..len].copy_from_slice(src); + Some(SmallCStr(dest)) + } + } + + /// Return pointer to C string. + pub const fn as_ptr(&self) -> *const libc::c_char { + self.0.as_ptr().cast() + } +} + +impl TryFrom<&str> for SmallCStr { + type Error = (); + + fn try_from(value: &str) -> Result { + SmallCStr::new(&value).ok_or(()) + } +} + +/// Either type, for APIs accepting two types. 
+pub enum Either { + A(A), + B(B), +} + +#[cfg(test)] +mod test { + use super::{SmallCStr, SMALL_STR_SIZE}; + use std::convert::TryInto; + + #[test] + fn test_create() { + let src = "\x30\x31\x32\x33"; + let scs = SmallCStr::new(&src).unwrap(); + assert_eq!(&scs.0[..5], &[0x30, 0x31, 0x32, 0x33, 0x00]); + + let src = b"abcd1234"; + let scs = SmallCStr::new(&src).unwrap(); + assert_eq!( + &scs.0[..9], + &[0x61, 0x62, 0x63, 0x64, 0x31, 0x32, 0x33, 0x34, 0x00] + ); + } + + #[test] + fn test_contain_null() { + let src = "\x30\x00\x32\x33"; + let scs = SmallCStr::new(&src); + assert_eq!(scs, None); + + let src = "\x30\x31\x32\x33\x00"; + let scs = SmallCStr::new(&src); + assert_eq!(scs, None); + } + + #[test] + fn test_too_large() { + let src = (0..SMALL_STR_SIZE).map(|_| 'a').collect::(); + let scs = SmallCStr::new(&src); + assert_eq!(scs, None); + + let src = (0..SMALL_STR_SIZE + 10).map(|_| 'a').collect::(); + let scs = SmallCStr::new(&src); + assert_eq!(scs, None); + } + + #[test] + fn test_try_into() { + let src = "\x30\x31\x32\x33"; + let scs: Result = src.try_into(); + assert!(scs.is_ok()); + + let src = (0..SMALL_STR_SIZE).map(|_| 'a').collect::(); + let scs: Result = src.as_str().try_into(); + assert!(scs.is_err()); + + let src = (0..SMALL_STR_SIZE + 10).map(|_| 'a').collect::(); + let scs: Result = src.as_str().try_into(); + assert!(scs.is_err()); + } +} diff --git a/src/llvm.rs b/src/llvm.rs new file mode 100644 index 0000000..ed6b930 --- /dev/null +++ b/src/llvm.rs @@ -0,0 +1,534 @@ +//! Safe wrapper around the LLVM C API. +//! +//! References returned from the LLVM API are tied to the `'llvm` lifetime which is bound to the +//! context where the objects are created in. +//! We do not offer wrappers to remove or delete any objects in the context and therefore all the +//! references will be valid for the lifetime of the context. 
+ +use llvm_sys::analysis::{LLVMVerifierFailureAction, LLVMVerifyFunction}; +use llvm_sys::core::{ + LLVMAddFunction, LLVMAppendBasicBlockInContext, LLVMBuildFAdd, LLVMBuildFCmp, LLVMBuildFMul, + LLVMBuildFSub, LLVMBuildRet, LLVMBuildUIToFP, LLVMConstReal, LLVMContextCreate, + LLVMContextDispose, LLVMCountBasicBlocks, LLVMCountParams, LLVMCreateBuilderInContext, + LLVMDisposeBuilder, LLVMDisposeModule, LLVMDoubleTypeInContext, LLVMDumpModule, LLVMDumpType, + LLVMDumpValue, LLVMGetNamedFunction, LLVMGetParam, LLVMGetReturnType, LLVMGetTypeKind, + LLVMGetValueKind, LLVMGetValueName2, LLVMModuleCreateWithNameInContext, + LLVMPositionBuilderAtEnd, LLVMSetValueName2, LLVMTypeOf, +}; +use llvm_sys::prelude::{ + LLVMBasicBlockRef, LLVMBool, LLVMBuilderRef, LLVMContextRef, LLVMModuleRef, LLVMTypeRef, + LLVMValueRef, +}; +use llvm_sys::{LLVMRealPredicate, LLVMTypeKind, LLVMValueKind}; + +use std::convert::TryFrom; +use std::ffi::CStr; +use std::marker::PhantomData; +use std::ops::Deref; + +use crate::SmallCStr; + +// Definition of LLVM C API functions using our `repr(transparent)` types. +extern "C" { + fn LLVMFunctionType( + ReturnType: Type<'_>, + ParamTypes: *mut Type<'_>, + ParamCount: ::libc::c_uint, + IsVarArg: LLVMBool, + ) -> LLVMTypeRef; + fn LLVMBuildCall2( + arg1: LLVMBuilderRef, + arg2: Type<'_>, + Fn: FnValue<'_>, + Args: *mut Value<'_>, + NumArgs: ::libc::c_uint, + Name: *const ::libc::c_char, + ) -> LLVMValueRef; +} + +// ==================== +// Module / Context +// ==================== + +/// Wrapper for a LLVM Module with its own LLVM Context. +pub struct Module { + ctx: LLVMContextRef, + module: LLVMModuleRef, +} + +impl<'llvm> Module { + /// Create a new Module instance. + /// + /// # Panics + /// + /// Panics if creating the context or the module fails. 
+ pub fn new() -> Self { + let (ctx, module) = unsafe { + let c = LLVMContextCreate(); + let m = LLVMModuleCreateWithNameInContext(b"module\0".as_ptr().cast(), c); + assert!(!c.is_null() && !m.is_null()); + (c, m) + }; + + Module { ctx, module } + } + + /// Dump LLVM IR emitted into the Module to stdout. + pub fn dump(&self) { + unsafe { LLVMDumpModule(self.module) }; + } + + /// Get a type reference representing a `f64` float. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn type_f64(&self) -> Type<'llvm> { + let type_ref = unsafe { LLVMDoubleTypeInContext(self.ctx) }; + Type::new(type_ref) + } + + /// Get a type reference representing a `fn(args) -> ret` function. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn type_fn(&'llvm self, args: &mut [Type<'llvm>], ret: Type<'llvm>) -> Type<'llvm> { + let type_ref = unsafe { + LLVMFunctionType( + ret, + args.as_mut_ptr(), + args.len() as libc::c_uint, + 0, /* IsVarArg */ + ) + }; + Type::new(type_ref) + } + + /// Add a function with the given `name` and `fn_type` to the module and return a value + /// reference representing the function. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer or `name` could not be converted to a + /// [`SmallCStr`]. + pub fn add_fn(&'llvm self, name: &str, fn_type: Type<'llvm>) -> FnValue<'llvm> { + debug_assert_eq!( + fn_type.kind(), + LLVMTypeKind::LLVMFunctionTypeKind, + "Expected a function type when adding a function!" + ); + + let name = SmallCStr::try_from(name) + .expect("Failed to convert 'name' argument to small C string!"); + + let value_ref = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type.0) }; + FnValue::new(value_ref) + } + + /// Get a function value reference to the function with the given `name` if it was previously + /// added to the module with [`add_fn`][Module::add_fn]. 
+ /// + /// # Panics + /// + /// Panics if `name` could not be converted to a [`SmallCStr`]. + pub fn get_fn(&'llvm self, name: &str) -> Option> { + let name = SmallCStr::try_from(name) + .expect("Failed to convert 'name' argument to small C string!"); + + let value_ref = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) }; + + (!value_ref.is_null()).then(|| FnValue::new(value_ref)) + } + + /// Append a Basic Block to the end of the function referenced by the value reference + /// `fn_value`. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn append_basic_block(&'llvm self, fn_value: FnValue<'llvm>) -> BasicBlock<'llvm> { + let block = unsafe { + LLVMAppendBasicBlockInContext( + self.ctx, + fn_value.value_ref(), + b"block\0".as_ptr().cast(), + ) + }; + assert!(!block.is_null()); + + BasicBlock(block, PhantomData) + } +} + +impl Drop for Module { + fn drop(&mut self) { + unsafe { + LLVMDisposeModule(self.module); + LLVMContextDispose(self.ctx); + } + } +} + +// =========== +// Builder +// =========== + +/// Wrapper for a LLVM IR Builder. +pub struct Builder<'llvm> { + builder: LLVMBuilderRef, + _ctx: PhantomData<&'llvm ()>, +} + +impl<'llvm> Builder<'llvm> { + /// Create a new LLVM IR Builder with the `module`s context. + /// + /// # Panics + /// + /// Panics if creating the IR Builder fails. + pub fn with_ctx(module: &'llvm Module) -> Builder<'llvm> { + let builder = unsafe { LLVMCreateBuilderInContext(module.ctx) }; + assert!(!builder.is_null()); + + Builder { + builder, + _ctx: PhantomData, + } + } + + /// Position the IR Builder at the end of the given Basic Block. + pub fn pos_at_end(&self, bb: BasicBlock<'llvm>) { + unsafe { + LLVMPositionBuilderAtEnd(self.builder, bb.0); + } + } + + /// Emit a [fadd](https://llvm.org/docs/LangRef.html#fadd-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. 
+ pub fn fadd(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { + debug_assert!(lhs.is_f64(), "fadd: Expected f64 as lhs operand!"); + debug_assert!(rhs.is_f64(), "fadd: Expected f64 as rhs operand!"); + + let value_ref = unsafe { + LLVMBuildFAdd( + self.builder, + lhs.value_ref(), + rhs.value_ref(), + b"fadd\0".as_ptr().cast(), + ) + }; + Value::new(value_ref) + } + + /// Emit a [fsub](https://llvm.org/docs/LangRef.html#fsub-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn fsub(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { + debug_assert!(lhs.is_f64(), "fsub: Expected f64 as lhs operand!"); + debug_assert!(rhs.is_f64(), "fsub: Expected f64 as rhs operand!"); + + let value_ref = unsafe { + LLVMBuildFSub( + self.builder, + lhs.value_ref(), + rhs.value_ref(), + b"fsub\0".as_ptr().cast(), + ) + }; + Value::new(value_ref) + } + + /// Emit a [fmul](https://llvm.org/docs/LangRef.html#fmul-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn fmul(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { + debug_assert!(lhs.is_f64(), "fmul: Expected f64 as lhs operand!"); + debug_assert!(rhs.is_f64(), "fmul: Expected f64 as rhs operand!"); + + let value_ref = unsafe { + LLVMBuildFMul( + self.builder, + lhs.value_ref(), + rhs.value_ref(), + b"fmul\0".as_ptr().cast(), + ) + }; + Value::new(value_ref) + } + + /// Emit a [fcmp ult](https://llvm.org/docs/LangRef.html#fcmp-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. 
+ pub fn fcmpult(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { + debug_assert!(lhs.is_f64(), "fcmplt: Expected f64 as lhs operand!"); + debug_assert!(rhs.is_f64(), "fcmplt: Expected f64 as rhs operand!"); + + let value_ref = unsafe { + LLVMBuildFCmp( + self.builder, + LLVMRealPredicate::LLVMRealULT, + lhs.value_ref(), + rhs.value_ref(), + b"fcmplt\0".as_ptr().cast(), + ) + }; + Value::new(value_ref) + } + + /// Emit a [uitofp](https://llvm.org/docs/LangRef.html#uitofp-to-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn uitofp(&self, val: Value<'llvm>, dest_type: Type<'llvm>) -> Value<'llvm> { + debug_assert!(val.is_int(), "uitofp: Expected integer operand!"); + + let value_ref = unsafe { + LLVMBuildUIToFP( + self.builder, + val.value_ref(), + dest_type.0, + b"uitofp\0".as_ptr().cast(), + ) + }; + Value::new(value_ref) + } + + /// Emit a [call](https://llvm.org/docs/LangRef.html#call-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn call(&self, fn_value: FnValue<'llvm>, args: &mut [Value<'llvm>]) -> Value<'llvm> { + let value_ref = unsafe { + LLVMBuildCall2( + self.builder, + fn_value.ret_type(), + fn_value, + args.as_mut_ptr(), + args.len() as libc::c_uint, + b"call\0".as_ptr().cast(), + ) + }; + Value::new(value_ref) + } + + /// Emit a [ret](https://llvm.org/docs/LangRef.html#ret-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn ret(&self, ret: Value<'llvm>) { + let ret = unsafe { LLVMBuildRet(self.builder, ret.value_ref()) }; + assert!(!ret.is_null()); + } +} + +impl Drop for Builder<'_> { + fn drop(&mut self) { + unsafe { LLVMDisposeBuilder(self.builder) } + } +} + +// ============== +// BasicBlock +// ============== + +/// Wrapper for a LLVM Basic Block. 
+#[derive(Copy, Clone)] +pub struct BasicBlock<'llvm>(LLVMBasicBlockRef, PhantomData<&'llvm ()>); + +// ======== +// Type +// ======== + +/// Wrapper for a LLVM Type Reference. +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct Type<'llvm>(LLVMTypeRef, PhantomData<&'llvm ()>); + +impl<'llvm> Type<'llvm> { + fn new(type_ref: LLVMTypeRef) -> Self { + assert!(!type_ref.is_null()); + Type(type_ref, PhantomData) + } + + fn kind(&self) -> LLVMTypeKind { + unsafe { LLVMGetTypeKind(self.0) } + } + + /// Dump the LLVM Type to stdout. + pub fn dump(&self) { + unsafe { LLVMDumpType(self.0) }; + } + + /// Get a value reference representing the const `f64` value. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn const_f64(self, n: f64) -> Value<'llvm> { + debug_assert_eq!( + self.kind(), + LLVMTypeKind::LLVMDoubleTypeKind, + "Expected a double type when creating const f64 value!" + ); + + let value_ref = unsafe { LLVMConstReal(self.0, n) }; + Value::new(value_ref) + } +} + +// ========= +// Value +// ========= + +/// Wrapper for a LLVM Value Reference. +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct Value<'llvm>(LLVMValueRef, PhantomData<&'llvm ()>); + +impl<'llvm> Value<'llvm> { + fn new(value_ref: LLVMValueRef) -> Self { + assert!(!value_ref.is_null()); + Value(value_ref, PhantomData) + } + + #[inline] + fn value_ref(&self) -> LLVMValueRef { + self.0 + } + + fn kind(&self) -> LLVMValueKind { + unsafe { LLVMGetValueKind(self.value_ref()) } + } + + /// Dump the LLVM Value to stdout. + pub fn dump(&self) { + unsafe { LLVMDumpValue(self.value_ref()) }; + } + + /// Get a type reference representing the type of the given value reference. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn type_of(&self) -> Type<'llvm> { + let type_ref = unsafe { LLVMTypeOf(self.value_ref()) }; + Type::new(type_ref) + } + + /// Set the name for the given value reference. 
+ /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn set_name(&self, name: &str) { + unsafe { LLVMSetValueName2(self.value_ref(), name.as_ptr().cast(), name.len()) }; + } + + /// Get the name for the given value reference. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn get_name(&self) -> &'llvm str { + let name = unsafe { + let mut len: libc::size_t = 0; + let name = LLVMGetValueName2(self.0, &mut len as _); + assert!(!name.is_null()); + + CStr::from_ptr(name) + }; + + // TODO: Does this string live for the time of the LLVM context?! + name.to_str() + .expect("Expected valid UTF8 string from LLVM API") + } + + /// Check if value is of `f64` type. + pub fn is_f64(&self) -> bool { + self.type_of().kind() == LLVMTypeKind::LLVMDoubleTypeKind + } + + /// Check if value is of integer type. + pub fn is_int(&self) -> bool { + self.type_of().kind() == LLVMTypeKind::LLVMIntegerTypeKind + } +} + +/// Wrapper for a LLVM Value Reference specialized for contexts where function values are needed. +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct FnValue<'llvm>(Value<'llvm>); + +impl<'llvm> Deref for FnValue<'llvm> { + type Target = Value<'llvm>; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'llvm> FnValue<'llvm> { + fn new(value_ref: LLVMValueRef) -> Self { + let value = Value::new(value_ref); + debug_assert_eq!( + value.kind(), + LLVMValueKind::LLVMFunctionValueKind, + "Expected a fn value when constructing FnValue!" + ); + + FnValue(value) + } + + /// Get a type reference representing the return value of the given function value. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn ret_type(&self) -> Type<'llvm> { + let type_ref = unsafe { LLVMGetReturnType(LLVMTypeOf(self.value_ref())) }; + Type::new(type_ref) + } + + /// Get the number of function arguments for the given function value. 
+ pub fn args(&self) -> usize { + unsafe { LLVMCountParams(self.value_ref()) as usize } + } + + /// Get a value reference for the function argument at index `idx`. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer or indexed out of bounds. + pub fn arg(&self, idx: usize) -> Value<'llvm> { + assert!(idx < self.args()); + + let value_ref = unsafe { LLVMGetParam(self.value_ref(), idx as libc::c_uint) }; + Value::new(value_ref) + } + + /// Get the number of Basic Blocks for the given function value. + pub fn basic_blocks(&self) -> usize { + unsafe { LLVMCountBasicBlocks(self.value_ref()) as usize } + } + + /// Verify that the given function is valid. + pub fn verify(&self) -> bool { + unsafe { + LLVMVerifyFunction( + self.value_ref(), + LLVMVerifierFailureAction::LLVMPrintMessageAction, + ) == 0 + } + } +} diff --git a/src/main.rs b/src/main.rs index 2646873..7160e04 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,24 @@ -mod lexer; -mod parser; +use llvm_kaleidoscope_rs::{ + codegen::Codegen, + lexer::{Lexer, Token}, + llvm, + parser::Parser, + Either, +}; -use lexer::{Lexer, Token}; -use parser::Parser; use std::io::Read; -fn handle_definition(p: &mut Parser) +fn handle_definition(p: &mut Parser, module: &llvm::Module) where I: Iterator, { match p.parse_definition() { - Ok(expr) => println!("Parse 'def'\n{:?}", expr), + Ok(func) => { + println!("Parse 'def'\n{:?}", func); + if let Ok(func) = Codegen::compile(module, Either::B(&func)) { + func.dump(); + } + } Err(err) => { eprintln!("Error: {:?}", err); p.get_next_token(); @@ -18,12 +26,17 @@ where } } -fn handle_extern(p: &mut Parser) +fn handle_extern(p: &mut Parser, module: &llvm::Module) where I: Iterator, { match p.parse_extern() { - Ok(expr) => println!("Parse 'extern'\n{:?}", expr), + Ok(proto) => { + println!("Parse 'extern'\n{:?}", proto); + if let Ok(proto) = Codegen::compile(module, Either::A(&proto)) { + proto.dump(); + } + } Err(err) => { eprintln!("Error: {:?}", err); 
p.get_next_token(); @@ -31,12 +44,17 @@ where } } -fn handle_top_level_expression(p: &mut Parser) +fn handle_top_level_expression(p: &mut Parser, module: &llvm::Module) where I: Iterator, { match p.parse_top_level_expr() { - Ok(expr) => println!("Parse top-level expression\n{:?}", expr), + Ok(func) => { + println!("Parse top-level expression\n{:?}", func); + if let Ok(func) = Codegen::compile(module, Either::B(&func)) { + func.dump(); + } + } Err(err) => { eprintln!("Error: {:?}", err); p.get_next_token(); @@ -49,16 +67,22 @@ fn main() { println!("ENTER to parse current input."); println!("C-d to exit."); + // Create lexer over stdin. let lexer = Lexer::new(std::io::stdin().bytes().filter_map(|v| { let v = v.ok()?; Some(v.into()) })); + // Create parser for kaleidoscope. let mut parser = Parser::new(lexer); // Throw first coin and initialize cur_tok. parser.get_next_token(); + // Initialize LLVM module with its own context. + // We will emit LLVM IR into this module. + let module = llvm::Module::new(); + loop { match *parser.cur_tok() { Token::Eof => break, @@ -66,9 +90,12 @@ fn main() { // Ignore top-level semicolon. parser.get_next_token() } - Token::Def => handle_definition(&mut parser), - Token::Extern => handle_extern(&mut parser), - _ => handle_top_level_expression(&mut parser), + Token::Def => handle_definition(&mut parser, &module), + Token::Extern => handle_extern(&mut parser, &module), + _ => handle_top_level_expression(&mut parser, &module), } } + + // Dump all the emitted LLVM IR to stdout. + module.dump(); } diff --git a/src/parser.rs b/src/parser.rs index 63f5a77..af69a87 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -19,15 +19,16 @@ pub enum ExprAST { /// which captures its name, and its argument names (thus implicitly the number /// of arguments the function takes). 
#[derive(Debug, PartialEq)] -pub struct PrototypeAST(String, Vec); +pub struct PrototypeAST(pub String, pub Vec); /// FunctionAST - This class represents a function definition itself. #[derive(Debug, PartialEq)] -pub struct FunctionAST(PrototypeAST, ExprAST); +pub struct FunctionAST(pub PrototypeAST, pub ExprAST); /// Parse result with String as Error type (to be compliant with tutorial). type ParseResult = Result; +/// Parser for the `kaleidoscope` language. pub struct Parser where I: Iterator, -- cgit v1.2.3