Skip to content

Commit

Permalink
Update safetyhook.
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyall committed Feb 2, 2024
1 parent 2c87089 commit 1e1c500
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 69 deletions.
161 changes: 113 additions & 48 deletions external/safetyhook/safetyhook.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,19 +349,6 @@ VmtHook create_vmt(void* object) {


namespace safetyhook {
class UnprotectMemory {
public:
UnprotectMemory(uint8_t* address, size_t size) : m_address{address}, m_size{size} {
VirtualProtect(m_address, m_size, PAGE_EXECUTE_READWRITE, &m_protect);
}

~UnprotectMemory() { VirtualProtect(m_address, m_size, m_protect, &m_protect); }

private:
uint8_t* m_address{};
size_t m_size{};
DWORD m_protect{};
};

#pragma pack(push, 1)
struct JmpE9 {
Expand Down Expand Up @@ -404,18 +391,25 @@ static auto make_jmp_ff(uint8_t* src, uint8_t* dst, uint8_t* data) {
return jmp;
}

static void emit_jmp_ff(uint8_t* src, uint8_t* dst, uint8_t* data, size_t size = sizeof(JmpFF)) {
[[nodiscard]] static std::expected<void, InlineHook::Error> emit_jmp_ff(
uint8_t* src, uint8_t* dst, uint8_t* data, size_t size = sizeof(JmpFF)) {
if (size < sizeof(JmpFF)) {
return;
return std::unexpected{InlineHook::Error::not_enough_space(dst)};
}

UnprotectMemory unprotect{src, size};
auto um = unprotect(src, size);

if (!um) {
return std::unexpected{InlineHook::Error::failed_to_unprotect(src)};
}

if (size > sizeof(JmpFF)) {
std::fill_n(src, size, static_cast<uint8_t>(0x90));
}

store(src, make_jmp_ff(src, dst, data));

return {};
}
#endif

Expand All @@ -427,18 +421,25 @@ constexpr auto make_jmp_e9(uint8_t* src, uint8_t* dst) {
return jmp;
}

static void emit_jmp_e9(uint8_t* src, uint8_t* dst, size_t size = sizeof(JmpE9)) {
[[nodiscard]] static std::expected<void, InlineHook::Error> emit_jmp_e9(
uint8_t* src, uint8_t* dst, size_t size = sizeof(JmpE9)) {
if (size < sizeof(JmpE9)) {
return;
return std::unexpected{InlineHook::Error::not_enough_space(dst)};
}

UnprotectMemory unprotect{src, size};
auto um = unprotect(src, size);

if (!um) {
return std::unexpected{InlineHook::Error::failed_to_unprotect(src)};
}

if (size > sizeof(JmpE9)) {
std::fill_n(src, size, static_cast<uint8_t>(0x90));
}

store(src, make_jmp_e9(src, dst));

return {};
}

static bool decode(ZydisDecodedInstruction* ix, uint8_t* ip) {
Expand Down Expand Up @@ -493,8 +494,8 @@ InlineHook& InlineHook::operator=(InlineHook&& other) noexcept {
m_trampoline_size = other.m_trampoline_size;
m_original_bytes = std::move(other.m_original_bytes);

other.m_target = 0;
other.m_destination = 0;
other.m_target = nullptr;
other.m_destination = nullptr;
other.m_trampoline_size = 0;
}

Expand Down Expand Up @@ -630,32 +631,48 @@ std::expected<void, InlineHook::Error> InlineHook::e9_hook(const std::shared_ptr
// jmp from trampoline to original.
auto src = reinterpret_cast<uint8_t*>(&trampoline_epilogue->jmp_to_original);
auto dst = m_target + m_original_bytes.size();
emit_jmp_e9(src, dst);

if (auto result = emit_jmp_e9(src, dst); !result) {
return std::unexpected{result.error()};
}

// jmp from trampoline to destination.
src = reinterpret_cast<uint8_t*>(&trampoline_epilogue->jmp_to_destination);
dst = m_destination;

#ifdef _M_X64
auto data = reinterpret_cast<uint8_t*>(&trampoline_epilogue->destination_address);
emit_jmp_ff(src, dst, data);

if (auto result = emit_jmp_ff(src, dst, data); !result) {
return std::unexpected{result.error()};
}
#else
emit_jmp_e9(src, dst);
if (auto result = emit_jmp_e9(src, dst); !result) {
return std::unexpected{result.error()};
}
#endif

std::optional<Error> error;

// jmp from original to trampoline.
execute_while_frozen(
[this, &trampoline_epilogue] {
const auto src = m_target;
const auto dst = reinterpret_cast<uint8_t*>(&trampoline_epilogue->jmp_to_destination);
emit_jmp_e9(src, dst, m_original_bytes.size());
[this, &trampoline_epilogue, &error] {
if (auto result = emit_jmp_e9(m_target,
reinterpret_cast<uint8_t*>(&trampoline_epilogue->jmp_to_destination), m_original_bytes.size());
!result) {
error = result.error();
}
},
[this](uint32_t, HANDLE, CONTEXT& ctx) {
[this](auto, auto, auto ctx) {
for (size_t i = 0; i < m_original_bytes.size(); ++i) {
fix_ip(ctx, m_target + i, m_trampoline.data() + i);
}
});

if (error) {
return std::unexpected{*error};
}

return {};
}

Expand Down Expand Up @@ -698,22 +715,31 @@ std::expected<void, InlineHook::Error> InlineHook::ff_hook(const std::shared_ptr
auto src = reinterpret_cast<uint8_t*>(&trampoline_epilogue->jmp_to_original);
auto dst = m_target + m_original_bytes.size();
auto data = reinterpret_cast<uint8_t*>(&trampoline_epilogue->original_address);
emit_jmp_ff(src, dst, data);

if (auto result = emit_jmp_ff(src, dst, data); !result) {
return std::unexpected{result.error()};
}

std::optional<Error> error;

// jmp from original to trampoline.
execute_while_frozen(
[this] {
const auto src = m_target;
const auto dst = m_destination;
const auto data = src + sizeof(JmpFF);
emit_jmp_ff(src, dst, data, m_original_bytes.size());
[this, &error] {
if (auto result = emit_jmp_ff(m_target, m_destination, m_target + sizeof(JmpFF), m_original_bytes.size());
!result) {
error = result.error();
}
},
[this](uint32_t, HANDLE, CONTEXT& ctx) {
[this](auto, auto, auto ctx) {
for (size_t i = 0; i < m_original_bytes.size(); ++i) {
fix_ip(ctx, m_target + i, m_trampoline.data() + i);
}
});

if (error) {
return std::unexpected{*error};
}

return {};
}
#endif
Expand All @@ -727,10 +753,11 @@ void InlineHook::destroy() {

execute_while_frozen(
[this] {
UnprotectMemory unprotect{m_target, m_original_bytes.size()};
std::copy(m_original_bytes.begin(), m_original_bytes.end(), m_target);
if (auto um = unprotect(m_target, m_original_bytes.size())) {
std::copy(m_original_bytes.begin(), m_original_bytes.end(), m_target);
}
},
[this](uint32_t, HANDLE, CONTEXT& ctx) {
[this](auto, auto, auto ctx) {
for (size_t i = 0; i < m_original_bytes.size(); ++i) {
fix_ip(ctx, m_trampoline.data() + i, m_target + i);
}
Expand Down Expand Up @@ -825,9 +852,9 @@ void MidHook::reset() {
}

std::expected<void, MidHook::Error> MidHook::setup(
const std::shared_ptr<Allocator>& allocator, uint8_t* target, MidHookFn destination) {
const std::shared_ptr<Allocator>& allocator, uint8_t* target, MidHookFn destination_fn) {
m_target = target;
m_destination = destination;
m_destination = destination_fn;

auto stub_allocation = allocator->allocate(asm_data.size());

Expand Down Expand Up @@ -893,7 +920,7 @@ NtGetNextThread(HANDLE ProcessHandle, HANDLE ThreadHandle, ACCESS_MASK DesiredAc

namespace safetyhook {
void execute_while_frozen(
const std::function<void()>& run_fn, const std::function<void(uint32_t, HANDLE, CONTEXT&)>& visit_fn) {
const std::function<void()>& run_fn, const std::function<void(ThreadId, ThreadHandle, ThreadContext)>& visit_fn) {
// Freeze all threads.
int num_threads_frozen;
auto first_run = true;
Expand Down Expand Up @@ -946,7 +973,8 @@ void execute_while_frozen(
}

if (visit_fn) {
visit_fn(thread_id, thread, thread_ctx);
visit_fn(static_cast<ThreadId>(thread_id), static_cast<ThreadHandle>(thread),
static_cast<ThreadContext>(&thread_ctx));
}

++num_threads_frozen;
Expand Down Expand Up @@ -989,21 +1017,23 @@ void execute_while_frozen(
}
}

void fix_ip(CONTEXT& ctx, uint8_t* old_ip, uint8_t* new_ip) {
void fix_ip(ThreadContext thread_ctx, uint8_t* old_ip, uint8_t* new_ip) {
auto* ctx = reinterpret_cast<CONTEXT*>(thread_ctx);

#ifdef _M_X64
auto ip = ctx.Rip;
auto ip = ctx->Rip;
#else
auto ip = ctx.Eip;
auto ip = ctx->Eip;
#endif

if (ip == reinterpret_cast<uintptr_t>(old_ip)) {
ip = reinterpret_cast<uintptr_t>(new_ip);
}

#ifdef _M_X64
ctx.Rip = ip;
ctx->Rip = ip;
#else
ctx.Eip = ip;
ctx->Eip = ip;
#endif
}
} // namespace safetyhook
Expand Down Expand Up @@ -1066,6 +1096,41 @@ bool is_executable(uint8_t* address) {

return is_page_executable(address);
}

UnprotectMemory::~UnprotectMemory() {
if (m_address != nullptr) {
DWORD old_protection;
VirtualProtect(m_address, m_size, m_original_protection, &old_protection);
}
}

UnprotectMemory::UnprotectMemory(UnprotectMemory&& other) noexcept {
*this = std::move(other);
}

UnprotectMemory& UnprotectMemory::operator=(UnprotectMemory&& other) noexcept {
if (this != &other) {
m_address = other.m_address;
m_size = other.m_size;
m_original_protection = other.m_original_protection;
other.m_address = nullptr;
other.m_size = 0;
other.m_original_protection = 0;
}

return *this;
}

std::optional<UnprotectMemory> unprotect(uint8_t* address, size_t size) {
DWORD old_protection;

if (!VirtualProtect(address, size, PAGE_EXECUTE_READWRITE, &old_protection)) {
return {};
}

return UnprotectMemory{address, size, old_protection};
}

} // namespace safetyhook

//
Expand Down
Loading

0 comments on commit 1e1c500

Please sign in to comment.