diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/codegen.rs | 9 | ||||
-rw-r--r-- | src/llvm.rs | 603 | ||||
-rw-r--r-- | src/llvm/builder.rs | 189 | ||||
-rw-r--r-- | src/llvm/mod.rs | 29 | ||||
-rw-r--r-- | src/llvm/module.rs | 157 | ||||
-rw-r--r-- | src/llvm/pass_manager.rs | 75 | ||||
-rw-r--r-- | src/llvm/type_.rs | 58 | ||||
-rw-r--r-- | src/llvm/value.rs | 168 |
8 files changed, 682 insertions, 606 deletions
diff --git a/src/codegen.rs b/src/codegen.rs index 0a7893e..08c3039 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use crate::llvm::{Builder, FnValue, FunctionPassManager, Module, Value}; +use crate::llvm::{FnValue, FunctionPassManager, IRBuilder, Module, Value}; use crate::parser::{ExprAST, FunctionAST, PrototypeAST}; use crate::Either; @@ -9,7 +9,7 @@ type CodegenResult<T> = Result<T, String>; /// Code generator from kaleidoscope AST to LLVM IR. pub struct Codegen<'llvm, 'a> { module: &'llvm Module, - builder: &'a Builder<'llvm>, + builder: &'a IRBuilder<'llvm>, fpm: &'a FunctionPassManager<'llvm>, } @@ -21,7 +21,7 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> { ) -> CodegenResult<FnValue<'llvm>> { let cg = Codegen { module, - builder: &Builder::with_ctx(module), + builder: &IRBuilder::with_ctx(module), fpm: &FunctionPassManager::with_ctx(module), }; let mut variables = HashMap::new(); @@ -129,7 +129,10 @@ impl<'llvm, 'a> Codegen<'llvm, 'a> { if let Ok(ret) = self.codegen_expr(body, named_values) { self.builder.ret(ret); assert!(the_function.verify()); + + // Run the optimization passes on the function. self.fpm.run(the_function); + Ok(the_function) } else { todo!("Failed to codegen function body, erase from module!"); diff --git a/src/llvm.rs b/src/llvm.rs deleted file mode 100644 index 5598cd4..0000000 --- a/src/llvm.rs +++ /dev/null @@ -1,603 +0,0 @@ -//! Safe wrapper around the LLVM C API. -//! -//! References returned from the LLVM API are tied to the `'llvm` lifetime which is bound to the -//! context where the objects are created in. -//! We do not offer wrappers to remove or delete any objects in the context and therefore all the -//! references will be valid for the liftime of the context. - -use llvm_sys::analysis::{LLVMVerifierFailureAction, LLVMVerifyFunction}; -use llvm_sys::core::{ - LLVMAddFunction, LLVMAppendBasicBlockInContext, LLVMBuildFAdd, LLVMBuildFCmp, LLVMBuildFMul, - LLVMBuildFSub, LLVMBuildRet, LLVMBuildUIToFP, LLVMConstReal, LLVMContextCreate, - LLVMContextDispose, LLVMCountBasicBlocks, LLVMCountParams, LLVMCreateBuilderInContext, - LLVMCreateFunctionPassManagerForModule, LLVMDisposeBuilder, LLVMDisposeModule, - LLVMDisposePassManager, LLVMDoubleTypeInContext, LLVMDumpModule, LLVMDumpType, LLVMDumpValue, - LLVMGetNamedFunction, LLVMGetParam, LLVMGetReturnType, LLVMGetTypeKind, LLVMGetValueKind, - LLVMGetValueName2, LLVMInitializeFunctionPassManager, LLVMModuleCreateWithNameInContext, - LLVMPositionBuilderAtEnd, LLVMRunFunctionPassManager, LLVMSetValueName2, LLVMTypeOf, -}; -use llvm_sys::prelude::{ - LLVMBasicBlockRef, LLVMBool, LLVMBuilderRef, LLVMContextRef, LLVMModuleRef, LLVMPassManagerRef, - LLVMTypeRef, LLVMValueRef, -}; -use llvm_sys::transforms::{ - instcombine::LLVMAddInstructionCombiningPass, - scalar::{LLVMAddCFGSimplificationPass, LLVMAddNewGVNPass, LLVMAddReassociatePass}, -}; -use llvm_sys::{LLVMRealPredicate, LLVMTypeKind, LLVMValueKind}; - -use std::convert::TryFrom; -use std::ffi::CStr; -use std::marker::PhantomData; -use std::ops::Deref; - -use crate::SmallCStr; - -// Definition of LLVM C API functions using our `repr(transparent)` types. -extern "C" { - fn LLVMFunctionType( - ReturnType: Type<'_>, - ParamTypes: *mut Type<'_>, - ParamCount: ::libc::c_uint, - IsVarArg: LLVMBool, - ) -> LLVMTypeRef; - fn LLVMBuildCall2( - arg1: LLVMBuilderRef, - arg2: Type<'_>, - Fn: FnValue<'_>, - Args: *mut Value<'_>, - NumArgs: ::libc::c_uint, - Name: *const ::libc::c_char, - ) -> LLVMValueRef; -} - -// ==================== -// Module / Context -// ==================== - -/// Wrapper for a LLVM Module with its own LLVM Context. -pub struct Module { - ctx: LLVMContextRef, - module: LLVMModuleRef, -} - -impl<'llvm> Module { - /// Create a new Module instance. - /// - /// # Panics - /// - /// Panics if creating the context or the module fails. - pub fn new() -> Self { - let (ctx, module) = unsafe { - let c = LLVMContextCreate(); - let m = LLVMModuleCreateWithNameInContext(b"module\0".as_ptr().cast(), c); - assert!(!c.is_null() && !m.is_null()); - (c, m) - }; - - Module { ctx, module } - } - - /// Dump LLVM IR emitted into the Module to stdout. - pub fn dump(&self) { - unsafe { LLVMDumpModule(self.module) }; - } - - /// Get a type reference representing a `f64` float. - /// - /// # Panics - /// - /// Panics if LLVM API returns a `null` pointer. - pub fn type_f64(&self) -> Type<'llvm> { - let type_ref = unsafe { LLVMDoubleTypeInContext(self.ctx) }; - Type::new(type_ref) - } - - /// Get a type reference representing a `fn(args) -> ret` function. - /// - /// # Panics - /// - /// Panics if LLVM API returns a `null` pointer. - pub fn type_fn(&'llvm self, args: &mut [Type<'llvm>], ret: Type<'llvm>) -> Type<'llvm> { - let type_ref = unsafe { - LLVMFunctionType( - ret, - args.as_mut_ptr(), - args.len() as libc::c_uint, - 0, /* IsVarArg */ - ) - }; - Type::new(type_ref) - } - - /// Add a function with the given `name` and `fn_type` to the module and return a value - /// reference representing the function. - /// - /// # Panics - /// - /// Panics if LLVM API returns a `null` pointer or `name` could not be converted to a - /// [`SmallCStr`]. - pub fn add_fn(&'llvm self, name: &str, fn_type: Type<'llvm>) -> FnValue<'llvm> { - debug_assert_eq!( - fn_type.kind(), - LLVMTypeKind::LLVMFunctionTypeKind, - "Expected a function type when adding a function!" - ); - - let name = SmallCStr::try_from(name) - .expect("Failed to convert 'name' argument to small C string!"); - - let value_ref = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type.0) }; - FnValue::new(value_ref) - } - - /// Get a function value reference to the function with the given `name` if it was previously - /// added to the module with [`add_fn`][Module::add_fn]. - /// - /// # Panics - /// - /// Panics if `name` could not be converted to a [`SmallCStr`]. - pub fn get_fn(&'llvm self, name: &str) -> Option<FnValue<'llvm>> { - let name = SmallCStr::try_from(name) - .expect("Failed to convert 'name' argument to small C string!"); - - let value_ref = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) }; - - (!value_ref.is_null()).then(|| FnValue::new(value_ref)) - } - - /// Append a Basic Block to the end of the function referenced by the value reference - /// `fn_value`. - /// - /// # Panics - /// - /// Panics if LLVM API returns a `null` pointer. - pub fn append_basic_block(&'llvm self, fn_value: FnValue<'llvm>) -> BasicBlock<'llvm> { - let block = unsafe { - LLVMAppendBasicBlockInContext( - self.ctx, - fn_value.value_ref(), - b"block\0".as_ptr().cast(), - ) - }; - assert!(!block.is_null()); - - BasicBlock(block, PhantomData) - } -} - -impl Drop for Module { - fn drop(&mut self) { - unsafe { - LLVMDisposeModule(self.module); - LLVMContextDispose(self.ctx); - } - } -} - -// =========== -// Builder -// =========== - -/// Wrapper for a LLVM IR Builder. -pub struct Builder<'llvm> { - builder: LLVMBuilderRef, - _ctx: PhantomData<&'llvm ()>, -} - -impl<'llvm> Builder<'llvm> { - /// Create a new LLVM IR Builder with the `module`s context. - /// - /// # Panics - /// - /// Panics if creating the IR Builder fails. - pub fn with_ctx(module: &'llvm Module) -> Builder<'llvm> { - let builder = unsafe { LLVMCreateBuilderInContext(module.ctx) }; - assert!(!builder.is_null()); - - Builder { - builder, - _ctx: PhantomData, - } - } - - /// Position the IR Builder at the end of the given Basic Block. - pub fn pos_at_end(&self, bb: BasicBlock<'llvm>) { - unsafe { - LLVMPositionBuilderAtEnd(self.builder, bb.0); - } - } - - /// Emit a [fadd](https://llvm.org/docs/LangRef.html#fadd-instruction) instruction. - /// - /// # Panics - /// - /// Panics if LLVM API returns a `null` pointer. - pub fn fadd(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { - debug_assert!(lhs.is_f64(), "fadd: Expected f64 as lhs operand!"); - debug_assert!(rhs.is_f64(), "fadd: Expected f64 as rhs operand!"); - - let value_ref = unsafe { - LLVMBuildFAdd( - self.builder, - lhs.value_ref(), - rhs.value_ref(), - b"fadd\0".as_ptr().cast(), - ) - }; - Value::new(value_ref) - } - - /// Emit a [fsub](https://llvm.org/docs/LangRef.html#fsub-instruction) instruction. - /// - /// # Panics - /// - /// Panics if LLVM API returns a `null` pointer. - pub fn fsub(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { - debug_assert!(lhs.is_f64(), "fsub: Expected f64 as lhs operand!"); - debug_assert!(rhs.is_f64(), "fsub: Expected f64 as rhs operand!"); - - let value_ref = unsafe { - LLVMBuildFSub( - self.builder, - lhs.value_ref(), - rhs.value_ref(), - b"fsub\0".as_ptr().cast(), - ) - }; - Value::new(value_ref) - } - - /// Emit a [fmul](https://llvm.org/docs/LangRef.html#fmul-instruction) instruction. - /// - /// # Panics - /// - /// Panics if LLVM API returns a `null` pointer. - pub fn fmul(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { - debug_assert!(lhs.is_f64(), "fmul: Expected f64 as lhs operand!"); - debug_assert!(rhs.is_f64(), "fmul: Expected f64 as rhs operand!"); - - let value_ref = unsafe { - LLVMBuildFMul( - self.builder, - lhs.value_ref(), - rhs.value_ref(), - b"fmul\0".as_ptr().cast(), - ) - }; - Value::new(value_ref) - } - - /// Emit a [fcmult](https://llvm.org/docs/LangRef.html#fcmp-instruction) instruction. - /// - /// # Panics - /// - /// Panics if LLVM API returns a `null` pointer. - pub fn fcmpult(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { - debug_assert!(lhs.is_f64(), "fcmplt: Expected f64 as lhs operand!"); - debug_assert!(rhs.is_f64(), "fcmplt: Expected f64 as rhs operand!"); - - let value_ref = unsafe { - LLVMBuildFCmp( - self.builder, - LLVMRealPredicate::LLVMRealULT, - lhs.value_ref(), - rhs.value_ref(), - b"fcmplt\0".as_ptr().cast(), - ) - }; - Value::new(value_ref) - } - - /// Emit a [uitofp](https://llvm.org/docs/LangRef.html#uitofp-to-instruction) instruction. - /// - /// # Panics - /// - /// Panics if LLVM API returns a `null` pointer. - pub fn uitofp(&self, val: Value<'llvm>, dest_type: Type<'llvm>) -> Value<'llvm> { - debug_assert!(val.is_int(), "uitofp: Expected integer operand!"); - - let value_ref = unsafe { - LLVMBuildUIToFP( - self.builder, - val.value_ref(), - dest_type.0, - b"uitofp\0".as_ptr().cast(), - ) - }; - Value::new(value_ref) - } - - /// Emit a [call](https://llvm.org/docs/LangRef.html#call-instruction) instruction. - /// - /// # Panics - /// - /// Panics if LLVM API returns a `null` pointer. - pub fn call(&self, fn_value: FnValue<'llvm>, args: &mut [Value<'llvm>]) -> Value<'llvm> { - let value_ref = unsafe { - LLVMBuildCall2( - self.builder, - fn_value.ret_type(), - fn_value, - args.as_mut_ptr(), - args.len() as libc::c_uint, - b"call\0".as_ptr().cast(), - ) - }; - Value::new(value_ref) - } - - /// Emit a [ret](https://llvm.org/docs/LangRef.html#ret-instruction) instruction. - /// - /// # Panics - /// - /// Panics if LLVM API returns a `null` pointer. - pub fn ret(&self, ret: Value<'llvm>) { - let ret = unsafe { LLVMBuildRet(self.builder, ret.value_ref()) }; - assert!(!ret.is_null()); - } -} - -impl Drop for Builder<'_> { - fn drop(&mut self) { - unsafe { LLVMDisposeBuilder(self.builder) } - } -} - -// ============== -// BasicBlock -// ============== - -/// Wrapper for a LLVM Basic Block. -#[derive(Copy, Clone)] -pub struct BasicBlock<'llvm>(LLVMBasicBlockRef, PhantomData<&'llvm ()>); - -// ======== -// Type -// ======== - -/// Wrapper for a LLVM Type Reference. -#[derive(Copy, Clone)] -#[repr(transparent)] -pub struct Type<'llvm>(LLVMTypeRef, PhantomData<&'llvm ()>); - -impl<'llvm> Type<'llvm> { - fn new(type_ref: LLVMTypeRef) -> Self { - assert!(!type_ref.is_null()); - Type(type_ref, PhantomData) - } - - fn kind(&self) -> LLVMTypeKind { - unsafe { LLVMGetTypeKind(self.0) } - } - - /// Dump the LLVM Type to stdout. - pub fn dump(&self) { - unsafe { LLVMDumpType(self.0) }; - } - - /// Get a value reference representing the const `f64` value. - /// - /// # Panics - /// - /// Panics if LLVM API returns a `null` pointer. - pub fn const_f64(self, n: f64) -> Value<'llvm> { - debug_assert_eq!( - self.kind(), - LLVMTypeKind::LLVMDoubleTypeKind, - "Expected a double type when creating const f64 value!" - ); - - let value_ref = unsafe { LLVMConstReal(self.0, n) }; - Value::new(value_ref) - } -} - -// ========= -// Value -// ========= - -/// Wrapper for a LLVM Value Reference. -#[derive(Copy, Clone)] -#[repr(transparent)] -pub struct Value<'llvm>(LLVMValueRef, PhantomData<&'llvm ()>); - -impl<'llvm> Value<'llvm> { - fn new(value_ref: LLVMValueRef) -> Self { - assert!(!value_ref.is_null()); - Value(value_ref, PhantomData) - } - - #[inline] - fn value_ref(&self) -> LLVMValueRef { - self.0 - } - - fn kind(&self) -> LLVMValueKind { - unsafe { LLVMGetValueKind(self.value_ref()) } - } - - /// Dump the LLVM Value to stdout. - pub fn dump(&self) { - unsafe { LLVMDumpValue(self.value_ref()) }; - } - - /// Get a type reference representing for the given value reference. - /// - /// # Panics - /// - /// Panics if LLVM API returns a `null` pointer. - pub fn type_of(&self) -> Type<'llvm> { - let type_ref = unsafe { LLVMTypeOf(self.value_ref()) }; - Type::new(type_ref) - } - - /// Set the name for the given value reference. - /// - /// # Panics - /// - /// Panics if LLVM API returns a `null` pointer. - pub fn set_name(&self, name: &str) { - unsafe { LLVMSetValueName2(self.value_ref(), name.as_ptr().cast(), name.len()) }; - } - - /// Get the name for the given value reference. - /// - /// # Panics - /// - /// Panics if LLVM API returns a `null` pointer. - pub fn get_name(&self) -> &'llvm str { - let name = unsafe { - let mut len: libc::size_t = 0; - let name = LLVMGetValueName2(self.0, &mut len as _); - assert!(!name.is_null()); - - CStr::from_ptr(name) - }; - - // TODO: Does this string live for the time of the LLVM context?! - name.to_str() - .expect("Expected valid UTF8 string from LLVM API") - } - - /// Check if value is of `f64` type. - pub fn is_f64(&self) -> bool { - self.type_of().kind() == LLVMTypeKind::LLVMDoubleTypeKind - } - - /// Check if value is of integer type. - pub fn is_int(&self) -> bool { - self.type_of().kind() == LLVMTypeKind::LLVMIntegerTypeKind - } -} - -/// Wrapper for a LLVM Value Reference specialized for contexts where function values are needed. -#[derive(Copy, Clone)] -#[repr(transparent)] -pub struct FnValue<'llvm>(Value<'llvm>); - -impl<'llvm> Deref for FnValue<'llvm> { - type Target = Value<'llvm>; - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl<'llvm> FnValue<'llvm> { - fn new(value_ref: LLVMValueRef) -> Self { - let value = Value::new(value_ref); - debug_assert_eq!( - value.kind(), - LLVMValueKind::LLVMFunctionValueKind, - "Expected a fn value when constructing FnValue!" - ); - - FnValue(value) - } - - /// Get a type reference representing the return value of the given function value. - /// - /// # Panics - /// - /// Panics if LLVM API returns a `null` pointer. - pub fn ret_type(&self) -> Type<'llvm> { - let type_ref = unsafe { LLVMGetReturnType(LLVMTypeOf(self.value_ref())) }; - Type::new(type_ref) - } - - /// Get the number of function arguments for the given function value. - pub fn args(&self) -> usize { - unsafe { LLVMCountParams(self.value_ref()) as usize } - } - - /// Get a value reference for the function argument at index `idx`. - /// - /// # Panics - /// - /// Panics if LLVM API returns a `null` pointer or indexed out of bounds. - pub fn arg(&self, idx: usize) -> Value<'llvm> { - assert!(idx < self.args()); - - let value_ref = unsafe { LLVMGetParam(self.value_ref(), idx as libc::c_uint) }; - Value::new(value_ref) - } - - /// Get the number of Basic Blocks for the given function value. - pub fn basic_blocks(&self) -> usize { - unsafe { LLVMCountBasicBlocks(self.value_ref()) as usize } - } - - /// Verify that the given function is valid. - pub fn verify(&self) -> bool { - unsafe { - LLVMVerifyFunction( - self.value_ref(), - LLVMVerifierFailureAction::LLVMPrintMessageAction, - ) == 0 - } - } -} - -// ======================= -// FunctionPassManager -// ======================= - -/// Wrapper for a LLVM Function PassManager (legacy). -pub struct FunctionPassManager<'llvm> { - fpm: LLVMPassManagerRef, - _ctx: PhantomData<&'llvm ()>, -} - -impl<'llvm> FunctionPassManager<'llvm> { - /// Create a new Function PassManager with the following optimization passes - /// - InstructionCombiningPass - /// - ReassociatePass - /// - NewGVNPass - /// - CFGSimplificationPass - /// - /// The list of selected optimization passes is taken from the tutorial chapter [LLVM - /// Optimization Passes](https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl04.html#id3). - pub fn with_ctx(module: &'llvm Module) -> FunctionPassManager<'llvm> { - let fpm = unsafe { - // Borrows module reference. - LLVMCreateFunctionPassManagerForModule(module.module) - }; - assert!(!fpm.is_null()); - - unsafe { - // Do simple "peephole" optimizations and bit-twiddling optzns. - LLVMAddInstructionCombiningPass(fpm); - // Reassociate expressions. - LLVMAddReassociatePass(fpm); - // Eliminate Common SubExpressions. - LLVMAddNewGVNPass(fpm); - // Simplify the control flow graph (deleting unreachable blocks, etc). - LLVMAddCFGSimplificationPass(fpm); - - let fail = LLVMInitializeFunctionPassManager(fpm); - assert_eq!(fail, 0); - } - - FunctionPassManager { - fpm, - _ctx: PhantomData, - } - } - - /// Run the optimization passes registered with the Function PassManager on the function - /// referenced by `fn_value`. - pub fn run(&'llvm self, fn_value: FnValue<'llvm>) { - unsafe { - // Returns 1 if any of the passes modified the function, false otherwise. - LLVMRunFunctionPassManager(self.fpm, fn_value.value_ref()); - } - } -} - -impl Drop for FunctionPassManager<'_> { - fn drop(&mut self) { - unsafe { - LLVMDisposePassManager(self.fpm); - } - } -} diff --git a/src/llvm/builder.rs b/src/llvm/builder.rs new file mode 100644 index 0000000..3d43b68 --- /dev/null +++ b/src/llvm/builder.rs @@ -0,0 +1,189 @@ +use llvm_sys::{ + core::{ + LLVMBuildFAdd, LLVMBuildFCmp, LLVMBuildFMul, LLVMBuildFSub, LLVMBuildRet, LLVMBuildUIToFP, + LLVMCreateBuilderInContext, LLVMDisposeBuilder, LLVMPositionBuilderAtEnd, + }, + prelude::{LLVMBuilderRef, LLVMValueRef}, + LLVMRealPredicate, +}; + +use std::marker::PhantomData; + +use super::{BasicBlock, FnValue, Module, Type, Value}; + +// Definition of LLVM C API functions using our `repr(transparent)` types. +extern "C" { + fn LLVMBuildCall2( + arg1: LLVMBuilderRef, + arg2: Type<'_>, + Fn: FnValue<'_>, + Args: *mut Value<'_>, + NumArgs: ::libc::c_uint, + Name: *const ::libc::c_char, + ) -> LLVMValueRef; +} + +/// Wrapper for a LLVM IR Builder. +pub struct IRBuilder<'llvm> { + builder: LLVMBuilderRef, + _ctx: PhantomData<&'llvm ()>, +} + +impl<'llvm> IRBuilder<'llvm> { + /// Create a new LLVM IR Builder with the `module`s context. + /// + /// # Panics + /// + /// Panics if creating the IR Builder fails. + pub fn with_ctx(module: &'llvm Module) -> IRBuilder<'llvm> { + let builder = unsafe { LLVMCreateBuilderInContext(module.ctx()) }; + assert!(!builder.is_null()); + + IRBuilder { + builder, + _ctx: PhantomData, + } + } + + /// Position the IR Builder at the end of the given Basic Block. + pub fn pos_at_end(&self, bb: BasicBlock<'llvm>) { + unsafe { + LLVMPositionBuilderAtEnd(self.builder, bb.0); + } + } + + /// Emit a [fadd](https://llvm.org/docs/LangRef.html#fadd-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn fadd(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { + debug_assert!(lhs.is_f64(), "fadd: Expected f64 as lhs operand!"); + debug_assert!(rhs.is_f64(), "fadd: Expected f64 as rhs operand!"); + + let value_ref = unsafe { + LLVMBuildFAdd( + self.builder, + lhs.value_ref(), + rhs.value_ref(), + b"fadd\0".as_ptr().cast(), + ) + }; + Value::new(value_ref) + } + + /// Emit a [fsub](https://llvm.org/docs/LangRef.html#fsub-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn fsub(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { + debug_assert!(lhs.is_f64(), "fsub: Expected f64 as lhs operand!"); + debug_assert!(rhs.is_f64(), "fsub: Expected f64 as rhs operand!"); + + let value_ref = unsafe { + LLVMBuildFSub( + self.builder, + lhs.value_ref(), + rhs.value_ref(), + b"fsub\0".as_ptr().cast(), + ) + }; + Value::new(value_ref) + } + + /// Emit a [fmul](https://llvm.org/docs/LangRef.html#fmul-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn fmul(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { + debug_assert!(lhs.is_f64(), "fmul: Expected f64 as lhs operand!"); + debug_assert!(rhs.is_f64(), "fmul: Expected f64 as rhs operand!"); + + let value_ref = unsafe { + LLVMBuildFMul( + self.builder, + lhs.value_ref(), + rhs.value_ref(), + b"fmul\0".as_ptr().cast(), + ) + }; + Value::new(value_ref) + } + + /// Emit a [fcmult](https://llvm.org/docs/LangRef.html#fcmp-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn fcmpult(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { + debug_assert!(lhs.is_f64(), "fcmplt: Expected f64 as lhs operand!"); + debug_assert!(rhs.is_f64(), "fcmplt: Expected f64 as rhs operand!"); + + let value_ref = unsafe { + LLVMBuildFCmp( + self.builder, + LLVMRealPredicate::LLVMRealULT, + lhs.value_ref(), + rhs.value_ref(), + b"fcmplt\0".as_ptr().cast(), + ) + }; + Value::new(value_ref) + } + + /// Emit a [uitofp](https://llvm.org/docs/LangRef.html#uitofp-to-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn uitofp(&self, val: Value<'llvm>, dest_type: Type<'llvm>) -> Value<'llvm> { + debug_assert!(val.is_int(), "uitofp: Expected integer operand!"); + + let value_ref = unsafe { + LLVMBuildUIToFP( + self.builder, + val.value_ref(), + dest_type.type_ref(), + b"uitofp\0".as_ptr().cast(), + ) + }; + Value::new(value_ref) + } + + /// Emit a [call](https://llvm.org/docs/LangRef.html#call-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn call(&self, fn_value: FnValue<'llvm>, args: &mut [Value<'llvm>]) -> Value<'llvm> { + let value_ref = unsafe { + LLVMBuildCall2( + self.builder, + fn_value.ret_type(), + fn_value, + args.as_mut_ptr(), + args.len() as libc::c_uint, + b"call\0".as_ptr().cast(), + ) + }; + Value::new(value_ref) + } + + /// Emit a [ret](https://llvm.org/docs/LangRef.html#ret-instruction) instruction. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn ret(&self, ret: Value<'llvm>) { + let ret = unsafe { LLVMBuildRet(self.builder, ret.value_ref()) }; + assert!(!ret.is_null()); + } +} + +impl Drop for IRBuilder<'_> { + fn drop(&mut self) { + unsafe { LLVMDisposeBuilder(self.builder) } + } +} diff --git a/src/llvm/mod.rs b/src/llvm/mod.rs new file mode 100644 index 0000000..f3e54a8 --- /dev/null +++ b/src/llvm/mod.rs @@ -0,0 +1,29 @@ +//! Safe wrapper around the LLVM C API. +//! +//! References returned from the LLVM API are tied to the `'llvm` lifetime which is bound to the +//! context where the objects are created in. +//! We do not offer wrappers to remove or delete any objects in the context and therefore all the +//! references will be valid for the liftime of the context. +//! +//! For the scope of this tutorial we mainly use assertions to validate the results from the LLVM +//! API calls. + +use llvm_sys::prelude::LLVMBasicBlockRef; + +use std::marker::PhantomData; + +mod builder; +mod module; +mod pass_manager; +mod type_; +mod value; + +pub use builder::IRBuilder; +pub use module::Module; +pub use pass_manager::FunctionPassManager; +pub use type_::Type; +pub use value::{FnValue, Value}; + +/// Wrapper for a LLVM Basic Block. +#[derive(Copy, Clone)] +pub struct BasicBlock<'llvm>(LLVMBasicBlockRef, PhantomData<&'llvm ()>); diff --git a/src/llvm/module.rs b/src/llvm/module.rs new file mode 100644 index 0000000..e0aad96 --- /dev/null +++ b/src/llvm/module.rs @@ -0,0 +1,157 @@ +use llvm_sys::{ + core::{ + LLVMAddFunction, LLVMAppendBasicBlockInContext, LLVMContextCreate, LLVMContextDispose, + LLVMDisposeModule, LLVMDoubleTypeInContext, LLVMDumpModule, LLVMGetNamedFunction, + LLVMModuleCreateWithNameInContext, + }, + prelude::{LLVMBool, LLVMContextRef, LLVMModuleRef, LLVMTypeRef}, + LLVMTypeKind, +}; + +use std::convert::TryFrom; +use std::marker::PhantomData; + +use super::{BasicBlock, FnValue, Type}; +use crate::SmallCStr; + +// Definition of LLVM C API functions using our `repr(transparent)` types. +extern "C" { + fn LLVMFunctionType( + ReturnType: Type<'_>, + ParamTypes: *mut Type<'_>, + ParamCount: ::libc::c_uint, + IsVarArg: LLVMBool, + ) -> LLVMTypeRef; +} + +/// Wrapper for a LLVM Module with its own LLVM Context. +pub struct Module { + ctx: LLVMContextRef, + module: LLVMModuleRef, +} + +impl<'llvm> Module { + /// Create a new Module instance. + /// + /// # Panics + /// + /// Panics if creating the context or the module fails. + pub fn new() -> Self { + let (ctx, module) = unsafe { + let c = LLVMContextCreate(); + let m = LLVMModuleCreateWithNameInContext(b"module\0".as_ptr().cast(), c); + assert!(!c.is_null() && !m.is_null()); + (c, m) + }; + + Module { ctx, module } + } + + /// Get the raw LLVM context reference. + #[inline] + pub(super) fn ctx(&self) -> LLVMContextRef { + self.ctx + } + + /// Get the raw LLVM module reference. + #[inline] + pub(super) fn module(&self) -> LLVMModuleRef { + self.module + } + + /// Dump LLVM IR emitted into the Module to stdout. + pub fn dump(&self) { + unsafe { LLVMDumpModule(self.module) }; + } + + /// Get a type reference representing a `f64` float. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn type_f64(&self) -> Type<'llvm> { + let type_ref = unsafe { LLVMDoubleTypeInContext(self.ctx) }; + Type::new(type_ref) + } + + /// Get a type reference representing a `fn(args) -> ret` function. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn type_fn(&'llvm self, args: &mut [Type<'llvm>], ret: Type<'llvm>) -> Type<'llvm> { + let type_ref = unsafe { + LLVMFunctionType( + ret, + args.as_mut_ptr(), + args.len() as libc::c_uint, + 0, /* IsVarArg */ + ) + }; + Type::new(type_ref) + } + + /// Add a function with the given `name` and `fn_type` to the module and return a value + /// reference representing the function. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer or `name` could not be converted to a + /// [`SmallCStr`]. + pub fn add_fn(&'llvm self, name: &str, fn_type: Type<'llvm>) -> FnValue<'llvm> { + debug_assert_eq!( + fn_type.kind(), + LLVMTypeKind::LLVMFunctionTypeKind, + "Expected a function type when adding a function!" + ); + + let name = SmallCStr::try_from(name) + .expect("Failed to convert 'name' argument to small C string!"); + + let value_ref = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type.type_ref()) }; + FnValue::new(value_ref) + } + + /// Get a function value reference to the function with the given `name` if it was previously + /// added to the module with [`add_fn`][Module::add_fn]. + /// + /// # Panics + /// + /// Panics if `name` could not be converted to a [`SmallCStr`]. + pub fn get_fn(&'llvm self, name: &str) -> Option<FnValue<'llvm>> { + let name = SmallCStr::try_from(name) + .expect("Failed to convert 'name' argument to small C string!"); + + let value_ref = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) }; + + (!value_ref.is_null()).then(|| FnValue::new(value_ref)) + } + + /// Append a Basic Block to the end of the function referenced by the value reference + /// `fn_value`. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn append_basic_block(&'llvm self, fn_value: FnValue<'llvm>) -> BasicBlock<'llvm> { + let block = unsafe { + LLVMAppendBasicBlockInContext( + self.ctx, + fn_value.value_ref(), + b"block\0".as_ptr().cast(), + ) + }; + assert!(!block.is_null()); + + BasicBlock(block, PhantomData) + } +} + +impl Drop for Module { + fn drop(&mut self) { + unsafe { + LLVMDisposeModule(self.module); + LLVMContextDispose(self.ctx); + } + } +} diff --git a/src/llvm/pass_manager.rs b/src/llvm/pass_manager.rs new file mode 100644 index 0000000..1551875 --- /dev/null +++ b/src/llvm/pass_manager.rs @@ -0,0 +1,75 @@ +use llvm_sys::{ + core::{ + LLVMCreateFunctionPassManagerForModule, LLVMDisposePassManager, + LLVMInitializeFunctionPassManager, LLVMRunFunctionPassManager, + }, + prelude::LLVMPassManagerRef, + transforms::{ + instcombine::LLVMAddInstructionCombiningPass, + scalar::{LLVMAddCFGSimplificationPass, LLVMAddNewGVNPass, LLVMAddReassociatePass}, + }, +}; + +use std::marker::PhantomData; + +use super::{FnValue, Module}; + +/// Wrapper for a LLVM Function PassManager (legacy). +pub struct FunctionPassManager<'llvm> { + fpm: LLVMPassManagerRef, + _ctx: PhantomData<&'llvm ()>, +} + +impl<'llvm> FunctionPassManager<'llvm> { + /// Create a new Function PassManager with the following optimization passes + /// - InstructionCombiningPass + /// - ReassociatePass + /// - NewGVNPass + /// - CFGSimplificationPass + /// + /// The list of selected optimization passes is taken from the tutorial chapter [LLVM + /// Optimization Passes](https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl04.html#id3). + pub fn with_ctx(module: &'llvm Module) -> FunctionPassManager<'llvm> { + let fpm = unsafe { + // Borrows module reference. + LLVMCreateFunctionPassManagerForModule(module.module()) + }; + assert!(!fpm.is_null()); + + unsafe { + // Do simple "peephole" optimizations and bit-twiddling optzns. + LLVMAddInstructionCombiningPass(fpm); + // Reassociate expressions. + LLVMAddReassociatePass(fpm); + // Eliminate Common SubExpressions. + LLVMAddNewGVNPass(fpm); + // Simplify the control flow graph (deleting unreachable blocks, etc). + LLVMAddCFGSimplificationPass(fpm); + + let fail = LLVMInitializeFunctionPassManager(fpm); + assert_eq!(fail, 0); + } + + FunctionPassManager { + fpm, + _ctx: PhantomData, + } + } + + /// Run the optimization passes registered with the Function PassManager on the function + /// referenced by `fn_value`. + pub fn run(&'llvm self, fn_value: FnValue<'llvm>) { + unsafe { + // Returns 1 if any of the passes modified the function, false otherwise. + LLVMRunFunctionPassManager(self.fpm, fn_value.value_ref()); + } + } +} + +impl Drop for FunctionPassManager<'_> { + fn drop(&mut self) { + unsafe { + LLVMDisposePassManager(self.fpm); + } + } +} diff --git a/src/llvm/type_.rs b/src/llvm/type_.rs new file mode 100644 index 0000000..8668c7c --- /dev/null +++ b/src/llvm/type_.rs @@ -0,0 +1,58 @@ +use llvm_sys::{ + core::{LLVMConstReal, LLVMDumpType, LLVMGetTypeKind}, + prelude::LLVMTypeRef, + LLVMTypeKind, +}; + +use std::marker::PhantomData; + +use super::Value; + +/// Wrapper for a LLVM Type Reference. +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct Type<'llvm>(LLVMTypeRef, PhantomData<&'llvm ()>); + +impl<'llvm> Type<'llvm> { + /// Create a new Type instance. + /// + /// # Panics + /// + /// Panics if `type_ref` is a null pointer. + pub(super) fn new(type_ref: LLVMTypeRef) -> Self { + assert!(!type_ref.is_null()); + Type(type_ref, PhantomData) + } + + /// Get the raw LLVM type reference. + #[inline] + pub(super) fn type_ref(&self) -> LLVMTypeRef { + self.0 + } + + /// Get the LLVM type kind for the given type reference. + pub(super) fn kind(&self) -> LLVMTypeKind { + unsafe { LLVMGetTypeKind(self.type_ref()) } + } + + /// Dump the LLVM Type to stdout. + pub fn dump(&self) { + unsafe { LLVMDumpType(self.type_ref()) }; + } + + /// Get a value reference representing the const `f64` value. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn const_f64(self, n: f64) -> Value<'llvm> { + debug_assert_eq!( + self.kind(), + LLVMTypeKind::LLVMDoubleTypeKind, + "Expected a double type when creating const f64 value!" + ); + + let value_ref = unsafe { LLVMConstReal(self.type_ref(), n) }; + Value::new(value_ref) + } +} diff --git a/src/llvm/value.rs b/src/llvm/value.rs new file mode 100644 index 0000000..9b79c69 --- /dev/null +++ b/src/llvm/value.rs @@ -0,0 +1,168 @@ +use llvm_sys::{ + analysis::{LLVMVerifierFailureAction, LLVMVerifyFunction}, + core::{ + LLVMCountBasicBlocks, LLVMCountParams, LLVMDumpValue, LLVMGetParam, LLVMGetReturnType, + LLVMGetValueKind, LLVMGetValueName2, LLVMSetValueName2, LLVMTypeOf, + }, + prelude::LLVMValueRef, + LLVMTypeKind, LLVMValueKind, +}; + +use std::ffi::CStr; +use std::marker::PhantomData; +use std::ops::Deref; + +use super::Type; + +/// Wrapper for a LLVM Value Reference. +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct Value<'llvm>(LLVMValueRef, PhantomData<&'llvm ()>); + +impl<'llvm> Value<'llvm> { + /// Create a new Value instance. + /// + /// # Panics + /// + /// Panics if `value_ref` is a null pointer. + pub(super) fn new(value_ref: LLVMValueRef) -> Self { + assert!(!value_ref.is_null()); + Value(value_ref, PhantomData) + } + + /// Get the raw LLVM value reference. + #[inline] + pub(super) fn value_ref(&self) -> LLVMValueRef { + self.0 + } + + /// Get the LLVM value kind for the given value reference. + pub(super) fn kind(&self) -> LLVMValueKind { + unsafe { LLVMGetValueKind(self.value_ref()) } + } + + /// Dump the LLVM Value to stdout. + pub fn dump(&self) { + unsafe { LLVMDumpValue(self.value_ref()) }; + } + + /// Get a type reference representing for the given value reference. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn type_of(&self) -> Type<'llvm> { + let type_ref = unsafe { LLVMTypeOf(self.value_ref()) }; + Type::new(type_ref) + } + + /// Set the name for the given value reference. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn set_name(&self, name: &str) { + unsafe { LLVMSetValueName2(self.value_ref(), name.as_ptr().cast(), name.len()) }; + } + + /// Get the name for the given value reference. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn get_name(&self) -> &'llvm str { + let name = unsafe { + let mut len: libc::size_t = 0; + let name = LLVMGetValueName2(self.0, &mut len as _); + assert!(!name.is_null()); + + CStr::from_ptr(name) + }; + + // TODO: Does this string live for the time of the LLVM context?! + name.to_str() + .expect("Expected valid UTF8 string from LLVM API") + } + + /// Check if value is of `f64` type. + pub fn is_f64(&self) -> bool { + self.type_of().kind() == LLVMTypeKind::LLVMDoubleTypeKind + } + + /// Check if value is of integer type. + pub fn is_int(&self) -> bool { + self.type_of().kind() == LLVMTypeKind::LLVMIntegerTypeKind + } +} + +/// Wrapper for a LLVM Value Reference specialized for contexts where function values are needed. +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct FnValue<'llvm>(Value<'llvm>); + +impl<'llvm> Deref for FnValue<'llvm> { + type Target = Value<'llvm>; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'llvm> FnValue<'llvm> { + /// Create a new FnValue instance. + /// + /// # Panics + /// + /// Panics if `value_ref` is a null pointer. + pub(super) fn new(value_ref: LLVMValueRef) -> Self { + let value = Value::new(value_ref); + debug_assert_eq!( + value.kind(), + LLVMValueKind::LLVMFunctionValueKind, + "Expected a fn value when constructing FnValue!" + ); + + FnValue(value) + } + + /// Get a type reference representing the return value of the given function value. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer. + pub fn ret_type(&self) -> Type<'llvm> { + let type_ref = unsafe { LLVMGetReturnType(LLVMTypeOf(self.value_ref())) }; + Type::new(type_ref) + } + + /// Get the number of function arguments for the given function value. + pub fn args(&self) -> usize { + unsafe { LLVMCountParams(self.value_ref()) as usize } + } + + /// Get a value reference for the function argument at index `idx`. + /// + /// # Panics + /// + /// Panics if LLVM API returns a `null` pointer or indexed out of bounds. + pub fn arg(&self, idx: usize) -> Value<'llvm> { + assert!(idx < self.args()); + + let value_ref = unsafe { LLVMGetParam(self.value_ref(), idx as libc::c_uint) }; + Value::new(value_ref) + } + + /// Get the number of Basic Blocks for the given function value. + pub fn basic_blocks(&self) -> usize { + unsafe { LLVMCountBasicBlocks(self.value_ref()) as usize } + } + + /// Verify that the given function is valid. + pub fn verify(&self) -> bool { + unsafe { + LLVMVerifyFunction( + self.value_ref(), + LLVMVerifierFailureAction::LLVMPrintMessageAction, + ) == 0 + } + } +} |