# CageTheUnicorn/ctu.py
# (repository listing metadata: 1248 lines, 34 KiB, Python)

import gzip, math, os, os.path, re, signal, struct, sys, yaml
from cmd import Cmd
import lz4.block
import colorama
from colorama import Fore, Back, Style
from unicorn import *
from unicorn.arm64_const import *
from capstone import *
from ceval import ceval, compile
import util
from util import *
import inlines, relocation
from svc import SvcHandler
from threadmanager import ThreadManager
import mmio
# Bit flags selecting which kinds of tracing/validation are active.
# They are OR-ed together into CTU.flags.
TRACE_NONE = 0
TRACE_INSTRUCTION = 1 << 0
TRACE_BLOCK = 1 << 1
TRACE_FUNCTION = 1 << 2
TRACE_MEMORY = 1 << 3
TRACE_MEMCHECK = 1 << 4
def colorDepth(depth):
    """Return the colorama escape sequence used for call-stack depth `depth`.

    The palette has seven entries and is indexed modulo its length.
    """
    palette = (
        Fore.RED,
        Fore.WHITE,
        Fore.GREEN,
        Fore.YELLOW,
        Style.BRIGHT + Fore.BLUE,
        Fore.MAGENTA,
        Fore.CYAN,
    )
    return palette[depth % len(palette)]
# How many instructions to execute per thread slice (100 million).
INSN_PER_SLICE = 100 * 1000 * 1000
class HandleJar(object):
    """Dictionary-like table mapping handle numbers to objects.

    Looking up a handle that is not present does not raise: it prints a
    diagnostic, enters the owning CTU's debugger via debugbreak() and then
    evaluates to None.
    """

    def __init__(self, ctu):
        self.jar = {}
        self.ctu = ctu

    def __setitem__(self, handle, obj):
        self.jar[handle] = obj

    def __getitem__(self, handle):
        try:
            return self.jar[handle]
        except KeyError:
            pass
        print('~~ Unknown handle 0x%08x ~~' % handle)
        self.ctu.debugbreak()
        return None

    def __delitem__(self, handle):
        del self.jar[handle]

    def __contains__(self, handle):
        return handle in self.jar

    def items(self):
        """Return the (handle, object) pairs currently stored."""
        return self.jar.items()

    def replace(self, old, new):
        """Point every handle that refers to the object `old` (by identity) at `new`."""
        swapped = {}
        for handle, obj in self.jar.items():
            swapped[handle] = new if obj is old else obj
        self.jar = swapped
class CTU(Cmd, object):
def __init__(self, flags=0):
    """Create the Unicorn/Capstone engines, install hooks and build the initial state.

    flags -- bitwise OR of TRACE_* constants to start with (default 0: no tracing).
    """
    Cmd.__init__(self)
    colorama.init()

    # The tracing hooks stay quiet until this becomes True at the end of __init__.
    self.initialized = False
    self.exiting = False
    self.firstLoad = True
    IPCMessage.ctu = self
    # Honour the constructor argument; it used to be ignored (flags forced to 0).
    self.flags = flags
    self.sublevel = 0
    self.breakpoints = set()
    self.watchpoints = []
    self.terminateOnFullSleep = False # Terminate when all threads go to sleep

    self.mu = Uc(UC_ARCH_ARM64, UC_MODE_ARM)
    self.md = Cs(CS_ARCH_ARM64, CS_MODE_ARM)

    self.mu.hook_add(UC_HOOK_CODE, self.hook_insn_bytes)
    self.mu.hook_add(UC_HOOK_BLOCK, self.trace_block)
    self.mu.hook_add(UC_HOOK_MEM_READ, self.trace_mem_read)
    self.mu.hook_add(UC_HOOK_MEM_WRITE, self.trace_mem_write)
    self.mu.hook_add(UC_HOOK_MEM_READ_UNMAPPED, self.trace_unmapped)
    self.mu.hook_add(UC_HOOK_MEM_WRITE_UNMAPPED, self.trace_unmapped)
    self.mu.hook_add(UC_HOOK_MEM_FETCH_UNMAPPED, self.trace_unmapped)

    self.insnhooks = {}
    self.fetchhooks = {}

    self.termaddr = 1 << 61 # Pseudoaddress upon which to terminate execution
    self.mu.mem_map(self.termaddr, 0x1000)
    self.mu.mem_write(self.termaddr, '\x1F\x20\x03\xD5') # NOP

    # One instruction hook per encoding 0xD53BD060 + i; each stores the current
    # thread's TLS base into register i (see tlshook). The immediately-invoked
    # outer lambda binds the loop value of i for each hook.
    # NOTE(review): these encodings look like "MRS Xi, TPIDRRO_EL0" -- confirm.
    for i in xrange(30):
        self.hookinsn(0xD53BD060 + i, (lambda i: lambda _, __: self.tlshook(i))(i))

    self.svch = SvcHandler(self)

    # reset() unmaps whatever is listed here, so it must exist before the call.
    self.mappings = []
    self.reset()

    self.enableFP()
    self.mu.mem_map(inlines.magicBase, 0x1000)

    self.execfunc = None
    self.initialized = True
def reset(self):
    """Return the emulator to a fresh, not-yet-started state.

    Unmaps every region created through map(), then recreates the MMIO
    window, hook tables, handle table, thread manager, heap and stack, and
    writes 0 to registers 0 through 31.
    """
    # Debugger / run-state flags.
    self.debugging = False
    self.started = False
    self.restarting = False
    self.singlestep = False
    self.mainGlobalScope = None
    self.skipbp = False

    # Drop every region that was created through map().
    for region_base, region_size in self.mappings:
        self.mu.mem_unmap(region_base, region_size)
    self.mappings = []
    self.checkmaps = {}
    self.checktriggers = []
    self.usHeapSize = 0

    # MMIO devices are placed back to back starting at 1 << 58; each entry is
    # (physical base, virtual base, size, device instance).
    self.mmiobase = 1 << 58
    self.mmiosize = 0
    self.mmiomap = []
    for device in mmio.mmioClasses:
        virt = self.mmiobase + self.mmiosize
        self.mmiomap.append((device.physbase, virt, device.size, device(self)))
        self.mmiosize += device.size
    self.map(self.mmiobase, self.mmiosize)

    self.writehooks = {}
    self.readhooks = {}

    # Handle table with two pre-registered Process handles.
    self.handles = HandleJar(self)
    self.handleIter = 0xd000
    self.handles[0xFFFF8001] = Process(0x1234)
    self.handles[0xDEADBEEF] = Process(0xDEAD)

    self.threads = ThreadManager(self)
    self.threadIter = 0

    self.exports = {}
    self.funcReplacements = {}
    self.loadbase = 0
    self.loadsize = 0

    # Heap starts at 7 << 24; malloc() is a bump allocator over it.
    self.heapbase = 7 << 24
    self.heapsize = 32 * 1024 * 1024 # 32MB
    self.heapoff = 0
    self.map(self.heapbase, self.heapsize)

    # The stack ends where the heap begins.
    self.stacktop = 7 << 24
    self.stacksize = 8 * 1024 * 1024 # 8MB
    stackbase = self.stacktop - self.stacksize
    self.map(stackbase, self.stacksize)

    # Zero-fill both regions without marking them initialised for memcheck.
    self.writemem(self.heapbase, '\0' * self.heapsize, check=False)
    self.writemem(stackbase, '\0' * self.stacksize, check=False)

    for index in xrange(32):
        self.reg(index, 0)
@property
def threadId(self):
    """Id of the current thread as a string, or '?' when no thread is current."""
    current = self.threads.current
    return '?' if current is None else str(current.id)
def newHandle(self, obj):
    """Store `obj` under the next handle number and return that number."""
    handle = self.handleIter
    self.handleIter = handle + 1
    self.handles[handle] = obj
    return handle
def replaceHandle(self, old, new):
    """Make every handle that refers to `old` refer to `new` (delegates to the handle table)."""
    table = self.handles
    table.replace(old, new)
def closeHandle(self, handle):
    """Close and forget `handle`.

    The two handles pre-registered by reset() (0xDEADBEEF, 0xFFFF8001) are
    never closed, and unknown handles are ignored. If the stored object has
    a close() method it is called before the handle is removed.
    """
    if handle in (0xDEADBEEF, 0xFFFF8001):
        return
    if handle not in self.handles:
        return
    obj = self.handles[handle]
    print('Closing handle: %s' % (obj,))
    if hasattr(obj, 'close'):
        obj.close()
    del self.handles[handle]
def map(self, base, size):
    """Map [base, base + size) into the emulator, widened to 4 KiB boundaries.

    The start is rounded down and the length rounded up to a multiple of
    0x1000. A region already recorded with the same (base, size) is left
    alone. With TRACE_MEMCHECK set, a zeroed tracking list of size >> 3
    entries is created for the region.
    """
    page_off = base & 0xFFF
    if page_off:
        base -= page_off
        size += page_off
    if size & 0xFFF:
        size = (size & 0xFFFFFFFFFFFFF000) + 0x1000

    region = (base, size)
    if region in self.mappings:
        return
    self.mappings.append(region)
    self.mu.mem_map(base, size)
    if self.flags & TRACE_MEMCHECK:
        self.checkmaps[base] = [0] * (size >> 3)
def unmap(self, base, size):
    """Unmap a region previously created by map() with the same base and size.

    base and size are widened to 4 KiB boundaries exactly as map() does;
    nothing happens unless that exact (base, size) pair is recorded.
    """
    page_off = base & 0xFFF
    if page_off:
        base -= page_off
        size += page_off
    if size & 0xFFF:
        size = (size & 0xFFFFFFFFFFFFF000) + 0x1000

    region = (base, size)
    if region not in self.mappings:
        return
    self.mappings.remove(region)
    self.mu.mem_unmap(base, size)
    if self.flags & TRACE_MEMCHECK:
        # The tracking list only exists if memcheck was already on when the
        # region was mapped; tolerate its absence instead of raising KeyError.
        self.checkmaps.pop(base, None)
def getmap(self, addr):
    """Return the (base, size) of the recorded mapping containing addr, or (-1, -1)."""
    for region in self.mappings:
        start, length = region
        if start <= addr < start + length:
            return start, length
    return -1, -1
def checkread(self, addr, size):
if not (self.flags & TRACE_MEMCHECK):
return
miss = None
base, rsize = self.getmap(addr)
for i in xrange(size):
caddr = addr + i
if not (base <= caddr < base + rsize):
base, rsize = self.getmap(caddr)
if base == -1:
continue
off = caddr - base
if (self.checkmaps[base][off >> 3] & (1 << (off & 7))) == 0:
miss = caddr
break
tlsbase = self.threads.current.tlsbase if self.threads.current is not None else 1 << 64
if miss is not None:
print '[%s:%s] Read from uninitialized memory at %s (reading %i bytes from %s)' % (self.threadId, raw(self.threads.current.lastinsn), raw(miss), size, raw(addr))
if tlsbase <= miss < tlsbase + 0x100:
self.debugbreak()
else:
for taddr, tsize in self.checktriggers:
if taddr <= miss < taddr + tsize:
self.debugbreak()
elif addr == tlsbase and size == 4:
self.checkwrite(addr, size, unset=True)
def checkwrite(self, addr, size, unset=False, trigger=False):
    """Memcheck: mark `size` bytes at addr as written (or unwritten if `unset`).

    No-op unless TRACE_MEMCHECK is set. Bytes outside every recorded mapping
    are skipped. With `trigger`, (addr, size) is also remembered as a range
    whose uninitialised reads enter the debugger; the oldest range is dropped
    once the list reaches five entries.
    """
    if not (self.flags & TRACE_MEMCHECK):
        return

    base, rsize = self.getmap(addr)
    for delta in xrange(size):
        cur = addr + delta
        if cur < base or cur >= base + rsize:
            base, rsize = self.getmap(cur)
            if base == -1:
                continue
        off = cur - base
        bit = 1 << (off & 7)
        if unset:
            self.checkmaps[base][off >> 3] &= 0xFF ^ bit
        else:
            self.checkmaps[base][off >> 3] |= bit

    if not trigger:
        return
    self.checktriggers.append((addr, size))
    if len(self.checktriggers) == 5:
        self.checktriggers.pop(0)
def setup(self, func):
self.execfunc = func
def load(self, dn):
load = yaml.load(file(dn + '/load.yaml'))
if 'nro' in load and not 'nxo' in load:
load['nxo'] = load['nro']
elif 'nso' in load and not 'nxo' in load:
load['nxo'] = load['nso']
if 'bundle' in load:
self.loadmemory(dn + '/' + load['bundle'])
elif 'mod' in load:
self.loadmod(dn + '/' + load['mod'])
elif 'nxo' in load:
if not isinstance(load['nxo'], list):
load['nxo'] = [load['nxo']]
ibase = 0x7100000000
self.loadbase = ibase
allImports = []
for name in load['nxo']:
print 'Loading', name
fn = dn + '/' + name
if os.path.exists(fn):
imports, exports = self.loadnso(fn, loadbase=ibase)
else:
imports, exports = self.loadnro(fn + '.nro', loadbase=ibase)
self.exports.update(exports)
allImports.append(imports)
ibase += 0x100000000
self.loadsize = ibase - self.loadbase
if True:#self.firstLoad and len(load['nso']) == 1:
Address.display_specialized = False
for imports in allImports:
for name, (addr, addend) in imports.items():
if name in self.exports:
self.write64(addr, self.exports[name] + addend)
else:
print 'Unresolved import:', name
if 'maps' in load:
for name, (base, fn) in load['maps'].items():
mapLoader(dn + '/' + fn, name, base)
if self.mainGlobalScope is not None:
self.mainGlobalScope.update(util.addressTypes)
self.firstLoad = False
def runExecFunc(self):
    """Call the function registered with setup(), passing this CTU.

    Its module globals are remembered first in mainGlobalScope. Does nothing
    when no function is registered.
    """
    func = self.execfunc
    if func is not None:
        self.mainGlobalScope = func.func_globals
        func(self)
def run(self, flags=0):
    """Reset the emulator and run the setup function.

    The flags in effect before the call are kept and OR-ed with `flags`.
    """
    previous = self.flags
    self.reset()
    self.flags = previous | flags
    self.runExecFunc()
def enableFP(self):
    """Run a short ARM64 stub that enables floating-point/SIMD access.

    A page is mapped temporarily at address 0, the stub is executed through
    call() with 3 << 20 in X0, and the page is unmapped again.

    Raises AssertionError if bits 20-21 of the stub's result are not both set.
    """
    addr = 0
    self.mu.mem_map(addr, 0x1000)
    # NOTE(review): the stub decodes to
    #   mrs x1, cpacr_el1 ; orr x0, x0, x1 ; msr cpacr_el1, x0 ;
    #   mrs x0, cpacr_el1 ; ret
    # i.e. it ORs the argument into CPACR_EL1 and returns the new value --
    # confirm against the ARM reference.
    self.mu.mem_write(addr, '\x41\x10\x38\xd5\x00\x00\x01\xaa\x40\x10\x18\xd5\x40\x10\x38\xd5\xc0\x03\x5f\xd6')
    # The call used to sit inside an `assert`, which `python -O` strips, so the
    # stub would never have run in optimised mode. Run it unconditionally.
    result = self.call(addr, 3 << 20)
    self.mu.mem_unmap(addr, 0x1000)
    if (result >> 20) & 3 != 3:
        raise AssertionError('could not enable FP/SIMD access (got 0x%x)' % result)
def loadmod(self, fn):
    """Load a raw image containing a MOD0 header and register it as address class 'Main'.

    The 32-bit value at file offset 4 locates the 'MOD0' magic. The BSS range
    and the load base are read from offsets relative to that header.
    """
    # Read via a context manager so the handle is closed (it used to leak).
    with open(fn, 'rb') as fp:
        data = fp.read()

    moff, = struct.unpack('<I', data[4:8])
    assert data[moff:moff+4] == 'MOD0'
    bssStart, bssEnd = struct.unpack('<II', data[moff+0x08:moff+0x10])
    # The BSS bounds are stored relative to the MOD0 header.
    bssStart, bssEnd = bssStart + moff, bssEnd + moff
    moff += struct.unpack('<I', data[moff+0x18:moff+0x1C])[0]
    base, = struct.unpack('<Q', data[moff+0x20:moff+0x28])

    overlength = 0
    if bssStart < len(data):
        # The file extends into BSS: drop that tail and map the BSS span as
        # extra space after the image instead.
        data = data[:bssStart]
        overlength = bssEnd - bssStart
    else:
        self.map(base + bssStart, bssEnd - bssStart)

    self.map(base, len(data) + overlength)
    self.writemem(base, data)
    defineAddressClass('Main', base, len(data))
def loadnso(self, fn, loadbase=0x7100000000, relocate=True):
    """Load an NSO0 file at `loadbase`.

    The three LZ4-compressed segments are decompressed and placed at the
    image offsets given in the header; room for `bsssize` more bytes is mapped
    after the image. The address class is named after the title-cased file
    name up to its first dot.

    Returns relocation.relocate(self, loadbase) when `relocate` is true
    (load() unpacks it as (imports, exports)), otherwise None.
    """
    # Read via a context manager so the handle is closed (it used to leak).
    with open(fn, 'rb') as fp:
        data = fp.read()
    assert data[0:4] == 'NSO0'

    # Each header triple is (file offset, image offset, decompressed size).
    toff, tloc, tsize = struct.unpack('<III', data[0x10:0x1C])
    roff, rloc, rsize = struct.unpack('<III', data[0x20:0x2C])
    doff, dloc, dsize = struct.unpack('<III', data[0x30:0x3C])
    bsssize, = struct.unpack('<I', data[0x3C:0x40])

    text = lz4.block.decompress(data[toff:roff], uncompressed_size=tsize)
    rd = lz4.block.decompress(data[roff:doff], uncompressed_size=rsize)
    data = lz4.block.decompress(data[doff:], uncompressed_size=dsize)

    # Assemble the image: pad with zeros up to each segment's offset, or cut
    # the image back when the previous segment runs past it.
    full = text
    if rloc >= len(full):
        full += '\0' * (rloc - len(full))
        full += rd
    else:
        full = full[:rloc] + rd
    if dloc >= len(full):
        full += '\0' * (dloc - len(full))
        full += data
    else:
        full = full[:dloc] + data

    self.map(loadbase, len(full) + bsssize)
    self.writemem(loadbase, full)

    defineAddressClass(fn.rsplit('/', 1)[-1].split('.', 1)[0].title(), loadbase, len(full))

    if relocate:
        return relocation.relocate(self, loadbase)
def loadnro(self, fn, loadbase=0x7100000000, relocate=True):
    """Load an NRO0 file at `loadbase`.

    The three uncompressed segments are placed at the image offsets given in
    the header; the BSS size comes from the embedded MOD0 header and is
    rounded up to a whole page. The address class is named after the
    title-cased file name up to its first dot.

    Returns relocation.relocate(self, loadbase) when `relocate` is true
    (load() unpacks it as (imports, exports)), otherwise None.
    """
    # Read via a context manager so the handle is closed (it used to leak).
    with open(fn, 'rb') as fp:
        data = fp.read()
    assert data[0x10:0x14] == 'NRO0'

    tloc, tsize, rloc, rsize, dloc, dsize = struct.unpack('<IIIIII', data[0x20:0x20 + 6 * 4])

    modoff, = struct.unpack('<I', data[4:8])
    assert data[modoff:modoff+4] == 'MOD0'
    bssoff, bssend = struct.unpack('<II', data[modoff+8:modoff+16])
    bsssize = bssend - bssoff

    text = data[tloc:tloc+tsize]
    rd = data[rloc:rloc+rsize]
    data = data[dloc:dloc+dsize]

    # Assemble the image: pad with zeros up to each segment's offset, or cut
    # the image back when the previous segment runs past it.
    full = text
    if rloc >= len(full):
        full += '\0' * (rloc - len(full))
        full += rd
    else:
        full = full[:rloc] + rd
    if dloc >= len(full):
        full += '\0' * (dloc - len(full))
        full += data
    else:
        full = full[:dloc] + data

    if len(full) < bssoff:
        full += '\0' * (bssoff - len(full))

    if bsssize & 0xFFF:
        bsssize = (bsssize & 0xFFFFF000) + 0x1000

    self.map(loadbase, len(full) + bsssize)
    # Best effort: a failing write is reported but does not abort the load.
    # Only ordinary exceptions are swallowed, so Ctrl-C and sys.exit() still
    # propagate (this used to be a bare `except:`).
    try:
        self.writemem(loadbase, full)
    except Exception:
        import traceback
        traceback.print_exc()

    defineAddressClass(fn.rsplit('/', 1)[-1].split('.', 1)[0].title(), loadbase, len(full))

    if relocate:
        return relocation.relocate(self, loadbase)
def loadmemory(self, fn):
if not os.path.isfile(fn) and os.path.isfile(fn + '.gz'):
with gzip.GzipFile(fn + '.gz', 'rb') as ifp:
with file(fn, 'wb') as ofp:
print 'Decompressing membundle'
ofp.write(ifp.read())
print 'Done!'
with file(fn, 'rb') as fp:
regions, mainbase, wkcbase = struct.unpack('<IQQ', fp.read(20))
rmap = []
for i in xrange(regions):
addr, dlen = struct.unpack('<QI', fp.read(12))
data = fp.read(dlen)
self.map(addr, dlen)
rmap.append((addr, dlen))
self.writemem(addr, data)
mainsize = 0
wkcsize = 0
inMain = inWKC = False
last = 0
rmap.sort(key=lambda x: x[0])
for (addr, dlen) in rmap:
if addr == mainbase:
inMain = True
last = addr
elif addr == wkcbase:
inWKC = True
last = addr
if (inMain or inWKC) and last != addr:
inMain = inWKC = False
elif inMain:
mainsize += dlen
last = addr + dlen
elif inWKC:
wkcsize += dlen
last = addr + dlen
defineAddressClass('Main', mainbase, mainsize)
defineAddressClass('Wkc', wkcbase, wkcsize)
def findMmioObj(self, virtaddr):
    """Translate an address inside the MMIO window.

    Returns (device, physical address) for the device whose virtual range
    contains `virtaddr`, or (None, 0) when the address is outside the window
    or matches no device.
    """
    window_end = self.mmiobase + self.mmiosize
    if virtaddr < self.mmiobase or virtaddr >= window_end:
        return None, 0
    for pbase, vbase, size, obj in self.mmiomap:
        if vbase <= virtaddr < vbase + size:
            return obj, pbase + (virtaddr - vbase)
    return None, 0
def trace_mem_read(self, mu, access, addr, size, value, user_data):
obj, paddr = self.findMmioObj(addr)
if obj is not None:
nval = obj.read(paddr, size)
if nval is None:
nval = 0
if size == 1:
self.write8(addr, nval)
elif size == 2:
self.write16(addr, nval)
elif size == 4:
self.write32(addr, nval)
elif size == 8:
self.write64(addr, nval)
#if addr == 0x710062b698:
# if size == 4:
# self.write32(addr, 0xdeadbeef)
if self.flags & TRACE_MEMORY:
value = None
if size == 1:
value = self.read8(addr, check=False)
elif size == 2:
value = self.read16(addr, check=False)
elif size == 4:
value = self.read32(addr, check=False)
elif size == 8:
value = self.read64(addr, check=False)
print '[%s:%s] %i byte read from %s = %s' % (self.threadId, raw(self.threads.current.lastinsn), size, raw(addr), '0x%x' % value if value is not None else 'unmapped')
if addr in self.readhooks:
val = self.readhooks[addr](self, size)
if val is not None:
if size == 1:
self.write8(addr, val)
elif size == 2:
self.write16(addr, val)
elif size == 4:
self.write32(addr, val)
elif size == 8:
self.write64(addr, val)
print '[%s:%s] %i detoured byte read from %s = %s' % (self.threadId, raw(self.threads.current.lastinsn), size, raw(addr), '0x%x' % val if val is not None else 'unmapped')
#self.debugbreak()
self.checkread(addr, size)
def trace_mem_write(self, mu, access, addr, size, value, user_data):
obj, paddr = self.findMmioObj(addr)
if obj is not None:
obj.write(paddr, size, value)
if self.flags & TRACE_MEMORY:
print '[%s:%s] %i byte write to %s = 0x%x' % (self.threadId, raw(self.threads.current.lastinsn), size, raw(addr), value)
if self.flags & TRACE_MEMCHECK:
self.checkwrite(addr, size)
if addr in self.writehooks:
if size == 1:
self.write8(addr, value)
elif size == 2:
self.write16(addr, value)
elif size == 4:
self.write32(addr, value)
elif size == 8:
self.write64(addr, value)
if self.writehooks[addr](self, addr, size, value):
del self.writehooks[addr]
def trace_unmapped(self, mu, access, addr, size, value, user_data):
if access == UC_MEM_FETCH_UNMAPPED:
print '[%s:%s] Unmapped fetch of %s' % (self.threadId, raw(self.threads.current.lastinsn), raw(addr))
elif access == UC_MEM_READ_UNMAPPED:
print '[%s:%s] Unmapped %i byte read from %s' % (self.threadId, raw(self.threads.current.lastinsn), size, raw(addr))
elif access == UC_MEM_WRITE_UNMAPPED:
print '[%s:%s] Unmapped %i byte write to %s = 0x%x' % (self.threadId, raw(self.threads.current.lastinsn), size, raw(addr), value)
self.debugbreak()
# Instruction-trace helper (called from hook_insn_bytes): disassembles the
# `size` bytes at `addr` and prints each instruction tagged with the current
# thread id, and prints the current value of X0.
def trace_insn(self, mu, addr, size, user_data):
# Stay silent until __init__ has finished.
if not self.initialized:
return
# mem_read()'s result is converted with str() before being disassembled.
for ins in self.md.disasm(str(mu.mem_read(addr, size)), addr):
print "[%s] 0x%08x: %s %s" % (self.threadId, ins.address, ins.mnemonic, ins.op_str)
print 'x0=0x%x' % self.reg(0)
def trace_block(self, mu, addr, size, user_data):
if not self.initialized or (self.flags & TRACE_BLOCK) == 0:
return
print '[%s] Block at %s' % (self.threadId, raw(addr, pad=True))
"""if addr == MainAddress(0x1ec928) or addr == MainAddress(0x1ec9e0) or addr == MainAddress(0x1ecab4):
print '\nFATAL:\n%s\n%s\n%s\n' % (
self.readmem(self.reg(0), 0x100).split('\0', 1)[0],
self.readmem(self.reg(1), 0x100).split('\0', 1)[0],
self.readmem(self.reg(2), 0x100).split('\0', 1)[0]
)
self.stop()"""
# Call-stack tracking (called from hook_insn_bytes for every instruction).
# Keeps thread.callstack up to date and, with TRACE_FUNCTION set, prints
# "-> addr" on function entry and "<- addr" on return, indented and coloured
# by call depth.
def trace_func(self, mu, addr, size, user_data):
thread = self.threads.current
# Nothing to track before __init__ finishes or while no thread is current.
if not self.initialized or thread is None:
return
# thread.blx was set by the previous instruction (see below), so `addr` is
# the first instruction of the function that was just called.
if thread.blx:
thread.callstack.append(addr)
if self.flags & TRACE_FUNCTION:
# NOTE(review): plen is only used by the commented-out print below.
plen = len('[%s]' % self.threadId + ' ' + ' ' * len(thread.callstack))
#print ' ' * plen + '-> X0 -- %s X1 -- %s' % (raw(self.reg(0)), raw(self.reg(1)))
print '[%s]' % self.threadId, ' ' * len(thread.callstack) + colorDepth(len(thread.callstack)) + '-> %s' % raw(addr), Style.RESET_ALL
thread.blx = False
# Classify the instruction at `addr` by masking its 32-bit encoding.
# NOTE(review): going by the names, the patterns are meant to match the
# A64 BL, BLR and RET encodings -- confirm against the ARM reference.
insn = self.read32(addr)
bl_mask = 0b11101100 << 24
bl_match = 0b10000100 << 24
blr_mask = 0b011011111010 << 20
blr_match = 0b010001100010 << 20
ret_mask = 0b011011110110 << 20
ret_match = 0b010001100100 << 20
# A call instruction: remember it so the next instruction is recorded as a
# function entry.
if (insn & bl_mask) == bl_match or (insn & blr_mask) == blr_match:
thread.blx = True
# A return instruction: pop the innermost frame (printing it when tracing).
elif (insn & ret_mask) == ret_match:
if self.flags & TRACE_FUNCTION:
if len(thread.callstack):
plen = len('[%s]' % self.threadId + ' ' + ' ' * len(thread.callstack))
print '[%s]' % self.threadId, ' ' * len(thread.callstack) + colorDepth(len(thread.callstack)) + '<- %s' % raw(thread.callstack.pop()), Style.RESET_ALL
#print ' ' * plen + '<- X0 -- %s' % raw(self.reg(0))
elif len(thread.callstack):
thread.callstack.pop()
# Per-instruction hook (registered for UC_HOOK_CODE in __init__). In order it
# handles thread time-slicing, inline/function replacements, fetch hooks,
# breakpoints / single-stepping / watchpoints, tracing, and finally hooks keyed
# on the instruction's 32-bit encoding.
def hook_insn_bytes(self, mu, addr, size, user_data):
self.threads.switched = False
# Count executed instructions; after INSN_PER_SLICE of them ask the thread
# manager to move on, and stop processing this instruction if it reports true.
self.threadIter += 1
if self.threadIter >= INSN_PER_SLICE:
self.threadIter = 0
if self.threads.next(pcOffset=0):
return
thread = self.threads.current
# Address handled by an inline implementation: run it, then continue at the
# address held in register 30 (named 'LR' in reg()).
if addr in inlines.reverse:
inlines.reverse[addr](self)
self.pc = self.reg(30)
self.threads.current.blx = False
return
# Address replaced through replaceFunction(): run the Python replacement.
elif addr in self.funcReplacements:
func = self.funcReplacements[addr]
func()
# If the replacement did not redirect execution itself, return to the caller.
if self.pc == addr:
self.pc = self.reg(30)
if thread.blx:
thread.blx = False
elif self.flags & TRACE_FUNCTION:
if len(thread.callstack):
print '[%s]' % self.threadId, ' ' * len(thread.callstack) + colorDepth(len(thread.callstack)) + '<- %s' % raw(thread.callstack.pop()), Style.RESET_ALL
return
# A restart was requested (see debugbreak); do no further work.
if self.restarting:
return
if addr in self.fetchhooks:
self.fetchhooks[addr]()
# skipbp is set when a break has just been taken at this point, so that
# resuming does not immediately break again on the same instruction.
if self.skipbp and not self.singlestep:
self.skipbp = False
elif self.singlestep or addr in self.breakpoints:
if self.singlestep:
self.singlestep = False
else:
print 'Breakpoint at %s' % raw(addr)
self.skipbp = True
self.debugbreak()
else:
# Watchpoints are (source text, compiled predicate) pairs; the first
# predicate that returns true enters the debugger.
for code, func in self.watchpoints:
if func(self):
print 'Watchpoint %s triggered at %s' % (code, raw(addr))
self.skipbp = True
self.debugbreak()
break
# Tracing. With both instruction and function tracing on, the function entry
# line is printed before the instruction line when this instruction is the
# target of a call (thread.blx set), otherwise after it.
if self.flags & TRACE_INSTRUCTION and self.flags & TRACE_FUNCTION:
if self.threads.current is not None and self.threads.current.blx:
self.trace_func(mu, addr, size, user_data)
self.trace_insn(mu, addr, size, user_data)
else:
self.trace_insn(mu, addr, size, user_data)
self.trace_func(mu, addr, size, user_data)
elif self.flags & TRACE_INSTRUCTION:
self.trace_insn(mu, addr, size, user_data)
self.trace_func(mu, addr, size, user_data)
else:
# trace_func also maintains the call stack, so it runs even without tracing.
self.trace_func(mu, addr, size, user_data)
self.threads.current.lastinsn = addr
# Hooks registered with hookinsn() are keyed on the little-endian 32-bit
# instruction word; a hook returning False makes execution skip the
# instruction by advancing pc by 4.
insn, = struct.unpack('<I', self.mu.mem_read(addr, 4))
if insn in self.insnhooks:
if self.insnhooks[insn](self, addr) == False:
self.pc += 4
def hookinsn(self, insn, func=None):
    """Register a hook for every instruction whose 32-bit encoding equals `insn`.

    The hook is called as func(ctu, addr); returning False makes the emulator
    skip the instruction (see hook_insn_bytes). Only one hook per encoding is
    allowed (asserted).

    Call as hookinsn(insn, func) (returns None) or use as a decorator,
    @ctu.hookinsn(insn). The decorator form now returns the function, so the
    decorated name is no longer rebound to None.
    """
    def register(handler):
        assert insn not in self.insnhooks
        self.insnhooks[insn] = handler
        return handler

    if func is None:
        return register
    register(func)
def hookfetch(self, addr, func=None):
    """Register `func` (called with no arguments) to run when execution reaches `addr`.

    The address is normalised with native(). Only one hook per address is
    allowed (asserted).

    Call as hookfetch(addr, func) (returns None) or use as a decorator,
    @ctu.hookfetch(addr). The decorator form now returns the function, so the
    decorated name is no longer rebound to None.
    """
    addr = native(addr)

    def register(handler):
        assert addr not in self.fetchhooks
        self.fetchhooks[addr] = handler
        return handler

    if func is None:
        return register
    register(func)
def hookread(self, addr):
    """Decorator registering a read hook for `addr` (normalised with native()).

    The hook is called as func(ctu, size) from trace_mem_read; a non-None
    result is written to the address before the read completes. Only one hook
    per address is allowed (asserted). The decorator now returns the function,
    so the decorated name is no longer rebound to None.
    """
    addr = native(addr)

    def register(handler):
        assert addr not in self.readhooks
        self.readhooks[addr] = handler
        return handler

    return register
def hookwrite(self, addr, func=None):
    """Register a write hook for `addr`.

    The hook is called as func(ctu, addr, size, value) from trace_mem_write;
    returning a true value removes it again. Only one hook per address is
    allowed (asserted).

    Call as hookwrite(addr, func) (returns None) or use as a decorator,
    @ctu.hookwrite(addr); the decorator form returns the function.
    """
    # Normalise like hookfetch()/hookread() so the key matches the plain
    # address the memory hook looks up.
    addr = native(addr)

    def register(handler):
        # The duplicate check used to look at fetchhooks by mistake.
        assert addr not in self.writehooks
        self.writehooks[addr] = handler
        return handler

    if func is None:
        return register
    register(func)
def replaceFunction(self, addr):
    """Decorator replacing the emulated function at `addr` with a Python function.

    The replacement is called as func(ctu, x0, x1, ...), receiving as many
    leading registers as it declares parameters after the first. A tuple or
    list result is stored element by element into registers 0, 1, ...; any
    other non-None result goes into register 0. The wrapper stored in
    funcReplacements exposes the Python function as `.original`.
    """
    addr = native(addr)

    def install(func):
        nargs = func.__code__.co_argcount - 1

        def trampoline():
            result = func(self, *[self.reg(n) for n in xrange(nargs)])
            if isinstance(result, (tuple, list)):
                for index, item in enumerate(result):
                    self.reg(index, item)
            elif result is not None:
                self.reg(0, result)

        self.funcReplacements[addr] = trampoline
        trampoline.original = func
        return func

    return install
def tlshook(self, reg):
    """Instruction-hook body: load the current thread's TLS base into register `reg`.

    Returns False, which makes hook_insn_bytes skip the hooked instruction.
    """
    tls = self.threads.current.tlsbase
    self.reg(reg, tls)
    return False
# Run emulated code starting at `pc` on a new thread and return the final
# value of X0. `pc` may be an address or the name of an export; positional
# args are passed to threads.create() after conversion with native(). The
# only keyword argument read is `_start` (default False).
def call(self, pc, *args, **kwargs):
_start = kwargs['_start'] if '_start' in kwargs else False
# Resolve an exported symbol name to its address.
if pc in self.exports:
print 'Calling', pc
pc = self.exports[pc]
thread = self.threads.create(native(pc), native(self.stacktop), *map(native, args))
# With _start, slots 2 and 3 of the thread's saved registers are overwritten
# with 0 and the thread's own handle.
# NOTE(review): regs[0] is read back as the pc further down, so slots 2 and 3
# presumably hold X0 and X1 -- confirm against ThreadManager.
if _start:
thread.regs[0+2] = 0
thread.regs[1+2] = thread.handle
# Only the outermost call drives the emulation loop.
if not self.started:
self.started = True
first = True
# Keep running while a thread switch happened or threads remain runnable.
while first or (not self.exiting and (self.threads.switched or len(self.threads.running))):
first = False
self.threads.current.thaw()
try:
# Execution ends when it reaches termaddr + 4, just past the NOP placed
# at termaddr in __init__.
self.mu.emu_start(native(pc), self.termaddr + 4)
except:
# Any failure is reported with the register state and ends the loop.
import traceback
traceback.print_exc()
print 'Exception at %s' % raw(self.threads.current.lastinsn)
self.dumpregs()
break
# Resume at the pc saved for whichever thread is current now.
if self.threads.current is not None:
pc = self.threads.current.regs[0]
if self.exiting:
sys.exit(0)
self.threads.clear()
self.started = False
# A restart requested from the debugger is propagated as Restart.
if self.restarting:
self.restarting = False
raise Restart()
return self.mu.reg_read(UC_ARM64_REG_X0)
def stop(self):
    """Request termination: point pc at the terminate pseudo-address and flag the emulator as exiting."""
    self.pc = self.termaddr
    self.exiting = True
def malloc(self, size):
    """Bump-allocate `size` bytes from the heap and return their address.

    Asserts that the heap is not exhausted (the offset has already been
    advanced when the assertion runs).
    """
    start = self.heapoff
    self.heapoff = start + size
    assert self.heapoff <= self.heapsize
    return self.heapbase + self.heapoff - size
def free(self, ptr):
    """Intentionally a no-op: memory handed out by malloc() is never reclaimed."""
    return None # Lol.
def writemem(self, addr, data, check=True):
    """Write `data` to emulated memory at `addr`.

    With `check`, the bytes are also marked as written for memcheck.
    Returns True on success and False if Unicorn raises UcError.
    """
    try:
        target = native(addr)
        self.mu.mem_write(target, data)
        if check:
            self.checkwrite(target, len(data))
    except unicorn.UcError:
        return False
    return True
def write8(self, addr, data, check=True):
    """Write `data` as one unsigned byte at `addr`; returns writemem()'s result."""
    packed = struct.pack('<B', data)
    return self.writemem(addr, packed, check=check)
def write16(self, addr, data, check=True):
    """Write `data` as a little-endian unsigned 16-bit value at `addr`; returns writemem()'s result."""
    packed = struct.pack('<H', data)
    return self.writemem(addr, packed, check=check)
def write32(self, addr, data, check=True):
    """Write `data` as a little-endian unsigned 32-bit value at `addr`; returns writemem()'s result."""
    packed = struct.pack('<I', data)
    return self.writemem(addr, packed, check=check)
def write64(self, addr, data, check=True):
    """Write `data` as a little-endian unsigned 64-bit value at `addr`; returns writemem()'s result."""
    packed = struct.pack('<Q', data)
    return self.writemem(addr, packed, check=check)
def readmem(self, addr, size, check=True):
    """Read `size` bytes of emulated memory at `addr` as a str.

    With `check` and TRACE_MEMCHECK set, the memcheck read check runs first.
    Returns None if Unicorn raises UcError.
    """
    try:
        source = native(addr)
        if check and self.flags & TRACE_MEMCHECK:
            self.checkread(source, size)
        blob = self.mu.mem_read(source, size)
        return str(blob)
    except unicorn.UcError:
        return None
def read8(self, addr, check=True):
    """Read an unsigned byte at `addr`; None if the memory cannot be read."""
    blob = self.readmem(addr, 1, check=check)
    if blob is None:
        return None
    return struct.unpack('<B', blob)[0]
def readS8(self, addr, check=True):
    """Read a signed byte at `addr`; None if the memory cannot be read."""
    blob = self.readmem(addr, 1, check=check)
    if blob is None:
        return None
    return struct.unpack('<b', blob)[0]
def read16(self, addr, check=True):
    """Read a little-endian unsigned 16-bit value at `addr`; None if the memory cannot be read."""
    blob = self.readmem(addr, 2, check=check)
    if blob is None:
        return None
    return struct.unpack('<H', blob)[0]
def readS16(self, addr, check=True):
    """Read a little-endian signed 16-bit value at `addr`; None if the memory cannot be read."""
    blob = self.readmem(addr, 2, check=check)
    if blob is None:
        return None
    return struct.unpack('<h', blob)[0]
def read32(self, addr, check=True):
    """Read a little-endian unsigned 32-bit value at `addr`; None if the memory cannot be read."""
    blob = self.readmem(addr, 4, check=check)
    if blob is None:
        return None
    return struct.unpack('<I', blob)[0]
def readS32(self, addr, check=True):
    """Read a little-endian signed 32-bit value at `addr`; None if the memory cannot be read."""
    blob = self.readmem(addr, 4, check=check)
    if blob is None:
        return None
    return struct.unpack('<i', blob)[0]
def read64(self, addr, check=True):
    """Read a little-endian unsigned 64-bit value at `addr`; None if the memory cannot be read."""
    blob = self.readmem(addr, 8, check=check)
    if blob is None:
        return None
    return struct.unpack('<Q', blob)[0]
def readS64(self, addr, check=True):
    """Read a little-endian signed 64-bit value at `addr`; None if the memory cannot be read."""
    blob = self.readmem(addr, 8, check=check)
    if blob is None:
        return None
    return struct.unpack('<q', blob)[0]
def readstring(self, addr):
    """Read a NUL-terminated string starting at `addr`.

    Reading stops at the first zero byte or at the first byte that cannot be
    read; the characters collected so far are returned. Returns None when
    `addr` is None.
    """
    if addr is None:
        return None
    chars = []
    while True:
        code = self.read8(addr)
        if code is None or code == 0:
            break
        chars.append(chr(code))
        addr += 1
    return ''.join(chars)
def memregions(self):
    """Yield (begin, end, perms) covering the address space from 0 to 1 << 64.

    Mapped regions come from Unicorn with `end` made exclusive; the gaps
    between them are yielded with perms == -1.
    """
    top = 1 << 64
    cursor = 0
    regions = sorted(self.mu.mem_regions(), key=lambda region: region[0])
    for begin, last, perms in regions:
        if cursor < begin:
            yield cursor, begin, -1
        cursor = last + 1
        yield begin, cursor, perms
    if cursor != top:
        yield cursor, top, -1
def reg(self, i, val=None):
    """Read or write a general-purpose register.

    i   -- register number 0-31, or a name ('X0'..'X31', 'LR' for 30, 'SP'
           for 31; case-insensitive).
    val -- omitted/None to read; otherwise the value to write (converted
           with native()).

    Returns the register value when reading, True after a write, and None
    when the register cannot be identified.
    """
    names = dict(('X%i' % number, number) for number in xrange(32))
    names['LR'] = 30
    names['SP'] = 31
    if isinstance(i, str) and i.upper() in names:
        i = names[i.upper()]

    if i <= 28:
        regid = UC_ARM64_REG_X0 + i
    elif i == 29 or i == 30:
        # X29/X30 are addressed relative to the X29 constant.
        regid = UC_ARM64_REG_X29 + i - 29
    elif i == 31:
        regid = UC_ARM64_REG_SP
    else:
        return None

    if val is None:
        return self.mu.reg_read(regid)
    self.mu.reg_write(regid, native(val))
    return True
@property
def pc(self):
    """Program counter of the emulated CPU (read and written through Unicorn)."""
    value = self.mu.reg_read(UC_ARM64_REG_PC)
    return value

@pc.setter
def pc(self, val):
    self.mu.reg_write(UC_ARM64_REG_PC, val)
def dumpregs(self):
sr = {30: 'LR', 31: 'SP'}
print '-' * 52
for i in xrange(0, 32, 2):
an = sr[i] if i in sr else 'X%i' % i
bn = sr[i + 1] if i + 1 in sr else 'X%i' % (i + 1)
an += ' ' * (3 - len(an))
bn += ' ' * (3 - len(bn))
print '%s - 0x%016x %s - 0x%016x' % (
an, self.reg(i),
bn, self.reg(i + 1)
)
print '-' * 52
print
def dumpmem(self, addr, size, check=False):
addr = native(addr)
data = self.readmem(addr, size, check=check)
if data is None:
print 'Unmapped memory at %s' % raw(addr)
return
data = map(ord, data)
fmt = '%%0%ix |' % (int(math.log(addr + size, 16)) + 1)
for i in xrange(0, len(data), 16):
print fmt % (addr + i),
ascii = ''
for j in xrange(16):
if i + j < len(data):
print '%02x' % data[i + j],
if 0x20 <= data[i+j] <= 0x7E:
ascii += chr(data[i+j])
else:
ascii += '.'
else:
print ' ',
ascii += ' '
if j == 7:
print '',
ascii += ' '
print '|', ascii
def reprompt(self):
    """Refresh the prompt: plain 'ctu> ' when idle, else with the current thread id and pc."""
    if not self.started:
        self.prompt = 'ctu> '
        return
    self.prompt = '[%s] ctu %s> ' % (self.threadId, raw(self.pc))
def debug(self, sub=False):
    """Run the interactive command loop until a command ends it.

    Ctrl-C prints a newline and re-enters the loop. `sublevel` counts nested
    invocations; when control returns to the outermost one the prompt is
    reset to 'ctu> '. (`sub` is accepted but not used.)
    """
    self.debugging = True
    self.reprompt()
    self.sublevel += 1
    try:
        finished = False
        while not finished:
            try:
                self.cmdloop()
                finished = True
            except KeyboardInterrupt:
                print
    finally:
        self.sublevel -= 1
        if self.sublevel == 1:
            self.prompt = 'ctu> '
def debugbreak(self):
    """Enter a nested debugger prompt.

    If a command raises Restart, mark the emulator as restarting and stop
    execution.
    """
    try:
        self.debug(sub=True)
    except Restart:
        self.restarting = True
        result = self.stop()
        return result
def print_topics(self, header, cmds, cmdlen, maxcol):
    """Cmd override: hide EOF and the one-letter aliases from help listings; print nothing without a header."""
    if header is None:
        return
    hidden = ('EOF', 'b', 'c', 's', 'r', 't')
    visible = [name for name in cmds if name not in hidden]
    Cmd.print_topics(self, header, visible, cmdlen, maxcol)
def do_EOF(self, line):
    # End of input (Ctrl-D): ask for confirmation, then exit. A second EOF at
    # the confirmation prompt counts as "yes". Deliberately has no docstring,
    # so the command stays undocumented in `help`.
    print
    try:
        confirmed = raw_input('Really exit? y/n: ').startswith('y')
    except EOFError:
        print
        confirmed = True
    if confirmed:
        self.exiting = True
        sys.exit()
def do_exit(self, line):
    "exit\nExit the debugger."
    # The docstring above is the text shown by `help exit`.
    raise SystemExit()
def do_start(self, line):
"""s/start
Start or restart the code."""
if self.sublevel != 1:
raise Restart()
while True:
self.reset()
try:
self.runExecFunc()
break
except Restart:
print 'got restart at', self.sublevel
continue
do_s = do_start
def do_trace(self, line):
"""t/trace (i/instruction | b/block | f/function | m/memory)
Toggles tracing of instructions, blocks, functions, or memory."""
if line.startswith('i'):
self.flags ^= TRACE_INSTRUCTION
print 'Instruction tracing', 'on' if self.flags & TRACE_INSTRUCTION else 'off'
elif line.startswith('b'):
self.flags ^= TRACE_BLOCK
print 'Block tracing', 'on' if self.flags & TRACE_BLOCK else 'off'
elif line.startswith('f'):
self.flags ^= TRACE_FUNCTION
print 'Function tracing', 'on' if self.flags & TRACE_FUNCTION else 'off'
elif line.startswith('m'):
self.flags ^= TRACE_MEMORY
print 'Memory tracing', 'on' if self.flags & TRACE_MEMORY else 'off'
else:
print 'Unknown trace flag'
do_t = do_trace
def do_memcheck(self, line):
"""mc/memcheck
Toggles memory access validations."""
self.flags ^= TRACE_MEMCHECK
print 'Memcheck', 'on' if self.flags & TRACE_MEMCHECK else 'off'
do_mc = do_memcheck
def do_break(self, addr):
"""b/break [name]
Without `name`, list breakpoints.
Given a symbol name or address, toggle breakpoint."""
if addr == '':
print 'Breakpoints:'
for addr in self.breakpoints:
print '*', addr
return
try:
addr = raw(addr)
except BadAddr:
print 'Invalid address/symbol'
return
if addr in self.breakpoints:
print 'Removing breakpoint at %s' % addr
self.breakpoints.remove(addr)
else:
print 'Breaking at %s' % addr
self.breakpoints.add(addr)
do_b = do_break
def complete_break(self, text, line, begidx, endidx):
    """Tab completion for symbol arguments.

    Matches known symbols against everything typed after the command name and
    returns each match with the part before `text` cut off.
    """
    typed = line.partition(' ')[2]
    skip = len(typed) - len(text)
    matches = []
    for sym in symbols.keys():
        if sym.startswith(typed):
            matches.append(sym[skip:])
    return matches
complete_b = complete_break
def do_bt(self, line):
    "bt\nPrints the call stack."
    thread = self.threads.current
    # With no current thread there is no call stack; this used to raise
    # AttributeError and take the debugger down.
    if thread is None:
        print('Not running')
        return
    print('Call stack:')
    # Innermost frame first.
    for i, x in enumerate(thread.callstack[::-1]):
        print('%03i: %s' % (i, raw(x)))
def do_sym(self, name):
"""sym <name>
Prints the address of a given symbol."""
try:
print raw(name)
except BadAddr:
print 'Invalid address/symbol'
complete_sym = complete_break
def do_continue(self, line):
    "c/continue\nContinues execution of the code."
    # At the outermost prompt nothing is running; otherwise returning True
    # ends this command loop so emulation resumes.
    if self.sublevel != 1:
        return True
    print('Not running')
do_c = do_continue
def do_next(self, line):
    "n/next\nStep to the next instruction."
    # At the outermost prompt nothing is running; otherwise arm single-step
    # and return True to end this command loop.
    if self.sublevel == 1:
        print('Not running')
        return None
    self.singlestep = True
    return True
do_n = do_next
def do_regs(self, line):
    ("r/reg/regs [reg [value]]\n"
     "No parameters: Display registers.\n"
     "Reg parameter: Display one register.\n"
     "Otherwise: Assign a value (always hex, or a symbol) to a register.")
    if line == '':
        return self.dumpregs()

    if ' ' in line:
        # "<reg> <value>": assign.
        r, v = line.split(' ', 1)
        try:
            v = raw(v)
            if self.reg(r, v) is None:
                print('Invalid register')
        except BadAddr:
            print('Invalid address/Symbol')
        return

    # "<reg>": display one register. reg() reports an unknown register with
    # None; the old test for False never matched, so formatting None raised.
    v = self.reg(line)
    if v is None:
        print('Invalid register')
    else:
        print('0x%016x' % v)
do_r = do_reg = do_regs
def do_exec(self, line):
"""x/exec <code>
Evaluates a given line of C."""
try:
val = ceval(line, self)
except:
import traceback
traceback.print_exc()
print 'Execution failed'
return
if val is not None:
print '0x%x' % val
do_x = do_exec
def do_dump(self, line):
"""dump <address> [size]
Dumps `size` (default: 0x100) bytes of memory at an address.
If the address takes the form `*register` (e.g. `*X1`) then the value of that register will be used."""
line = list(line.split(' '))
if len(line[0]) == 0:
print 'No address'
elif len(line) <= 2:
if len(line[0]) and line[0][0] == '*':
line[0] = self.reg(line[0][1:])
if line[0] is None:
print 'Invalid register'
return
else:
try:
line[0] = raw(line[0])
except BadAddr:
print 'Invalid address/symbol'
return
if len(line) == 2:
line[1] = parseInt(line[1])
if line[1] is None or line[1] >= 0x10000:
print 'Invalid size'
return
self.dumpmem(line[0], 0x100 if len(line) == 1 else line[1])
else:
print 'Too many parameters'
def do_save(self, line):
"""save <address> <size> <fn>
Writes `size` bytes of memory to a file.
If the address or size takes the form `*register` (e.g. `*X1`) then the value of that register will be used."""
line = list(line.split(' '))
addr, size, fn = line
addr = self.reg(addr[1:]) if addr.startswith('*') else parseInt(addr)
size = self.reg(size[1:]) if size.startswith('*') else parseInt(size)
with file(fn, 'wb') as fp:
fp.write(self.readmem(addr, size))
print 'Wrote to file'
def do_ad(self, line):
"""ad
Toggle address display specialization."""
Address.display_specialized = not Address.display_specialized
print '%s specialized address display' % ('Enabled' if Address.display_specialized else 'Disabled')
self.reprompt()
def do_watch(self, line):
"""w/watch [expression]
Breaks when expression evaluates to true.
Without an expression, list existing watchpoints."""
if line == '':
print 'Watchpoints:'
for code, _ in self.watchpoints:
print '*', code
return
if line in [code for code, _ in self.watchpoints]:
self.watchpoints = [(code, func) for code, func in self.watchpoints if code != line]
print 'Watchpoint deleted'
else:
self.watchpoints.append((line, compile(line)))
print 'Watchpoint added'
do_w = do_watch
def do_memregions(self, line):
"""mr/memregions
Displays mapped memory regions."""
print 'Mapped memory regions'
print '---------------------'
for begin, end, perms in self.memregions():
if perms != -1:
print '%016x - %016x' % (begin, end)
do_mr = do_memregions
def debug(*flags):
    """Decorator: wrap a setup function in a CTU and open the debugger prompt.

    Use bare (@debug) or with trace flags (@debug(TRACE_FUNCTION, ...)); the
    flags are OR-ed into the new CTU's flags. The decorated function is
    returned unchanged once the debugger session ends.
    """
    def launch(func):
        ctu = CTU()
        ctu.setup(func)
        mask = TRACE_NONE
        for flag in flags:
            mask = mask | flag
        ctu.flags |= mask
        ctu.debug()
        return func

    used_bare = len(flags) == 1 and callable(flags[0])
    if not used_bare:
        return launch
    # Bare use: the only argument is the function itself, not a flag.
    func = flags[0]
    flags = [TRACE_NONE]
    return launch(func)
def run(*flags):
    """Decorator: wrap a setup function in a CTU and run it immediately.

    Use bare (@run) or with trace flags (@run(TRACE_FUNCTION, ...)); the flags
    are OR-ed into the new CTU's flags. The decorated function is returned
    unchanged once the run ends.
    """
    def launch(func):
        ctu = CTU()
        ctu.setup(func)
        mask = TRACE_NONE
        for flag in flags:
            mask = mask | flag
        ctu.flags |= mask
        ctu.run()
        return func

    used_bare = len(flags) == 1 and callable(flags[0])
    if not used_bare:
        return launch
    # Bare use: the only argument is the function itself, not a flag.
    func = flags[0]
    flags = [TRACE_NONE]
    return launch(func)