diff options
Diffstat (limited to 'src/main.rs')
-rw-r--r-- | src/main.rs | 144 |
1 files changed, 144 insertions, 0 deletions
diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..400b2d7 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,144 @@ +use goblin::elf::{Elf, program_header::PT_LOAD}; +use std::convert::TryFrom; +use vdso_proxy_poc::{Error, JmpPad, MapEntry, Mmap, VirtAddr}; + +#[cfg(not(target_os = "linux"))] +compile_error!("This only makes sense on Linux, as we are poking the vdso."); + +/// Find the `[vdso]` entry in `/proc/self/maps`. +fn get_vdso() -> Result<MapEntry, Error> { + for line in std::fs::read_to_string("/proc/self/maps") + .map_err(|_| Error::FailedToReadMaps)? + .lines() + { + let map = MapEntry::from_line(line)?; + match &map.name { + Some(n) if n == "[vdso]" => return Ok(map), + _ => {} + } + } + Err(Error::VdsoSegmentNotFound) +} + +/// Create a copy of the `vdso` memory segment. Effectively allocates memory and copies the virtual +/// address range described by `vdso`. +/// +/// # Safety: +/// The caller must guarantee that the `vdso` argument describes a valid virtual address range by +/// its `address` and `length` fields. +#[allow(unused_unsafe)] +unsafe fn copy_vdso(vdso: &MapEntry) -> Option<Mmap> { + let bytes = { + let ptr = vdso.addr as *const u8; + let len = usize::try_from(vdso.len) + .expect("It's required that the segment length fits into a usize!"); + // SAFETY: Validity of ptr & len must be ensured by the caller. + unsafe { std::slice::from_raw_parts(ptr, len) } + }; + + Mmap::new_rwx_from(&bytes) +} + +/// Find the `symbol_name` in the vdso described by the [`MapEntry`] memory segment. +/// +/// # Safety: +/// The caller must guarantee that the `vdso` argument describes a valid virtual address range by +/// its `address` and `length` fields. +/// +/// # Note: +/// Currently the version of the symbol is not checked, technically this is an error which can be +/// fatal in case of a binary incompatibility, but that's accepted for this PoC. +#[allow(unused_unsafe)] +unsafe fn get_vdso_sym(vdso: &MapEntry, symbol_name: &str) -> Result<VirtAddr, Error> { + // Turn `vdso` maps entry into slice of bytes. + let bytes = { + let ptr = vdso.addr as *const u8; + let len = usize::try_from(vdso.len) + .expect("It's required that the segment length fits into a usize!"); + // SAFETY: Validity of ptr & len must be ensured by the caller. + unsafe { std::slice::from_raw_parts(ptr, len) } + }; + + // Parse vdso bytes as ELF. + let elf = Elf::parse(bytes).map_err(|_| Error::FailedToParseAsElf)?; + + // Compute the dynamic shared object (dso) base address. Symbol offsets are relative to this + // dso base address. + let dso_base = { + let phdr_load = elf + .program_headers + .iter() + .find(|p| p.p_type == PT_LOAD) + .ok_or(Error::LoadPhdrNotFound)?; + vdso.addr - phdr_load.p_offset - phdr_load.p_vaddr + }; + assert_ne!(dso_base, 0, "If the dso base address is 0 that means the symbols contain absolute addresses, we don't want to support that!"); + + // Try to find the requested symbol. + let sym = elf + .dynsyms + .iter() + .filter(|sym| sym.is_function()) + .find(|sym| matches!(elf.dynstrtab.get_at(sym.st_name), Some(sym) if sym == symbol_name)) + .ok_or(Error::SymbolNotFound(symbol_name.into()))?; + + // Compute the absolute virtual address of the requested symbol. + Ok(VirtAddr(dso_base + sym.st_value)) +} + +/// Represent the `struct timeval` C structure (see `man 2 gettimeofday`). +#[repr(C)] +struct Timeval { + tv_sec: i64, + tv_usec: i64, +} + +fn main() -> Result<(), Error> { + // This represents the _new_ vdso pages that the kernel mapped into the restoring process. + let orig_vdso = get_vdso()?; + + // This represents the _old_ vdso pages that were captured in the memory dump of the process + // checkpoint. + // + // SAFETY: orig_vdso describes a valid memory region as we got it from /proc/self/maps. + let copy_vdso = unsafe { copy_vdso(&orig_vdso).expect("Copy of vdso must succeed!") }; + + let (orig_sym_addr, copy_sym_addr) = unsafe { + // SAFETY: orig_vdso describes a valid memory region as we got it from /proc/self/maps. + let orig = get_vdso_sym(&orig_vdso, "__vdso_gettimeofday")?; + // SAFETY: copy_vdso describes a valid and owned memory allocation. + let copy = get_vdso_sym(©_vdso.as_ref(), "__vdso_gettimeofday")?; + + (orig, copy) + }; + + // As an example, install a trampoline for the `__vdso_gettimeofday` symbol. The trampoline is + // installed in the _old_ vdso pages, where the user code from the checkpoint image links to, + // and forwards the calls into the _new_ vdso pages. + let pad = JmpPad::to(orig_sym_addr); + // SAFETY: copy_sym_addr is a valid virtual address as we got it from the symbol lookup. + unsafe { pad.install_at(copy_sym_addr) }; + + let mut tv: Timeval = Timeval { + tv_sec: 0, + tv_usec: 0, + }; + + unsafe { + // Mimic a call to `__vdso_gettimeofday` from user code which is still linked to the _old_ + // vdso. + + // SAFETY: copy_sym_addr is a valid virtual address pointing to the `__vdso_gettimeofday` + // function. + let gettimeofday: extern "C" fn(*mut Timeval, *mut libc::c_void) -> i32 = + std::mem::transmute(copy_sym_addr.0 as *const ()); + + // Invoke the `__vdso_gettimeofday` function in the copied memory region (_old_ vdso). This + // should forward to the function in the original memory region. + gettimeofday(&mut tv as *mut Timeval, std::ptr::null_mut()); + } + + println!("Timeval tv_sec : {} tv_usec : {}", tv.tv_sec, tv.tv_usec); + + Ok(()) +} |