aboutsummaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs144
1 files changed, 144 insertions, 0 deletions
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..400b2d7
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,144 @@
+use goblin::elf::{Elf, program_header::PT_LOAD};
+use std::convert::TryFrom;
+use vdso_proxy_poc::{Error, JmpPad, MapEntry, Mmap, VirtAddr};
+
+#[cfg(not(target_os = "linux"))]
+compile_error!("This only makes sense on Linux, as we are poking the vdso.");
+
+/// Find the `[vdso]` entry in `/proc/self/maps`.
+fn get_vdso() -> Result<MapEntry, Error> {
+ for line in std::fs::read_to_string("/proc/self/maps")
+ .map_err(|_| Error::FailedToReadMaps)?
+ .lines()
+ {
+ let map = MapEntry::from_line(line)?;
+ match &map.name {
+ Some(n) if n == "[vdso]" => return Ok(map),
+ _ => {}
+ }
+ }
+ Err(Error::VdsoSegmentNotFound)
+}
+
+/// Create a copy of the `vdso` memory segment. Effectively allocates memory and copies the virtual
+/// address range described by `vdso`.
+///
+/// # Safety:
+/// The caller must guarantee that the `vdso` argument describes a valid virtual address range by
+/// its `address` and `length` fields.
+#[allow(unused_unsafe)]
+unsafe fn copy_vdso(vdso: &MapEntry) -> Option<Mmap> {
+ let bytes = {
+ let ptr = vdso.addr as *const u8;
+ let len = usize::try_from(vdso.len)
+ .expect("It's required that the segment length fits into a usize!");
+ // SAFETY: Validity of ptr & len must be ensured by the caller.
+ unsafe { std::slice::from_raw_parts(ptr, len) }
+ };
+
+ Mmap::new_rwx_from(&bytes)
+}
+
+/// Find the `symbol_name` in the vdso described by the [`MapEntry`] memory segment.
+///
+/// # Safety:
+/// The caller must guarantee that the `vdso` argument describes a valid virtual address range by
+/// its `address` and `length` fields.
+///
+/// # Note:
+/// Currently the version of the symbol is not checked, technically this is an error which can be
+/// fatal in case of a binary incompatibility, but that's accepted for this PoC.
+#[allow(unused_unsafe)]
+unsafe fn get_vdso_sym(vdso: &MapEntry, symbol_name: &str) -> Result<VirtAddr, Error> {
+ // Turn `vdso` maps entry into slice of bytes.
+ let bytes = {
+ let ptr = vdso.addr as *const u8;
+ let len = usize::try_from(vdso.len)
+ .expect("It's required that the segment length fits into a usize!");
+ // SAFETY: Validity of ptr & len must be ensured by the caller.
+ unsafe { std::slice::from_raw_parts(ptr, len) }
+ };
+
+ // Parse vdso bytes as ELF.
+ let elf = Elf::parse(bytes).map_err(|_| Error::FailedToParseAsElf)?;
+
+ // Compute the dynamic shared object (dso) base address. Symbol offsets are relative to this
+ // dso base address.
+ let dso_base = {
+ let phdr_load = elf
+ .program_headers
+ .iter()
+ .find(|p| p.p_type == PT_LOAD)
+ .ok_or(Error::LoadPhdrNotFound)?;
+ vdso.addr - phdr_load.p_offset - phdr_load.p_vaddr
+ };
+ assert_ne!(dso_base, 0, "If the dso base address is 0 that means the symbols contain absolute addresses, we don't want to support that!");
+
+ // Try to find the requested symbol.
+ let sym = elf
+ .dynsyms
+ .iter()
+ .filter(|sym| sym.is_function())
+ .find(|sym| matches!(elf.dynstrtab.get_at(sym.st_name), Some(sym) if sym == symbol_name))
+ .ok_or(Error::SymbolNotFound(symbol_name.into()))?;
+
+ // Compute the absolute virtual address of the requested symbol.
+ Ok(VirtAddr(dso_base + sym.st_value))
+}
+
+/// Represent the `struct timeval` C structure (see `man 2 gettimeofday`).
+#[repr(C)]
+struct Timeval {
+ tv_sec: i64,
+ tv_usec: i64,
+}
+
+fn main() -> Result<(), Error> {
+ // This represents the _new_ vdso pages that the kernel mapped into the restoring process.
+ let orig_vdso = get_vdso()?;
+
+ // This represents the _old_ vdso pages that were captured in the memory dump of the process
+ // checkpoint.
+ //
+ // SAFETY: orig_vdso describes a valid memory region as we got it from /proc/self/maps.
+ let copy_vdso = unsafe { copy_vdso(&orig_vdso).expect("Copy of vdso must succeed!") };
+
+ let (orig_sym_addr, copy_sym_addr) = unsafe {
+ // SAFETY: orig_vdso describes a valid memory region as we got it from /proc/self/maps.
+ let orig = get_vdso_sym(&orig_vdso, "__vdso_gettimeofday")?;
+ // SAFETY: copy_vdso describes a valid and owned memory allocation.
+ let copy = get_vdso_sym(&copy_vdso.as_ref(), "__vdso_gettimeofday")?;
+
+ (orig, copy)
+ };
+
+ // As an example, install a trampoline for the `__vdso_gettimeofday` symbol. The trampoline is
+ // installed in the _old_ vdso pages, where the user code from the checkpoint image links to,
+ // and forwards the calls into the _new_ vdso pages.
+ let pad = JmpPad::to(orig_sym_addr);
+ // SAFETY: copy_sym_addr is a valid virtual address as we got it from the symbol lookup.
+ unsafe { pad.install_at(copy_sym_addr) };
+
+ let mut tv: Timeval = Timeval {
+ tv_sec: 0,
+ tv_usec: 0,
+ };
+
+ unsafe {
+ // Mimic a call to `__vdso_gettimeofday` from user code which is still linked to the _old_
+ // vdso.
+
+ // SAFETY: copy_sym_addr is a valid virtual address pointing to the `__vdso_gettimeofday`
+ // function.
+ let gettimeofday: extern "C" fn(*mut Timeval, *mut libc::c_void) -> i32 =
+ std::mem::transmute(copy_sym_addr.0 as *const ());
+
+ // Invoke the `__vdso_gettimeofday` function in the copied memory region (_old_ vdso). This
+ // should forward to the function in the original memory region.
+ gettimeofday(&mut tv as *mut Timeval, std::ptr::null_mut());
+ }
+
+ println!("Timeval tv_sec : {} tv_usec : {}", tv.tv_sec, tv.tv_usec);
+
+ Ok(())
+}