Skip to content

Commit

Permalink
refactor: set_auto_device (#388)
Browse files Browse the repository at this point in the history
  • Loading branch information
MistEO authored Oct 13, 2024
1 parent 15938ca commit 53654ff
Show file tree
Hide file tree
Showing 13 changed files with 187 additions and 41 deletions.
2 changes: 0 additions & 2 deletions source/MaaFramework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ target_compile_definitions(MaaFramework PRIVATE MAA_FRAMEWORK_EXPORTS)
target_link_libraries(MaaFramework PRIVATE MaaUtils LibraryHolder ${OpenCV_LIBS} fastdeploy_ppocr
ONNXRuntime::ONNXRuntime HeaderOnlyLibraries)

# clang 15之后有ranges if (CMAKE_CXX_COMPILER_ID MATCHES ".*Clang") find_package(range-v3 REQUIRED)
# target_link_libraries(MaaFramework range-v3::range-v3) endif ()
add_dependencies(MaaFramework MaaUtils LibraryHolder)

if(WITH_ADB_CONTROLLER)
Expand Down
8 changes: 0 additions & 8 deletions source/MaaFramework/Resource/OCRResMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ OCRResMgr::OCRResMgr()
LogFunc;

option_.UseOrtBackend();

set_auto_device();
}

void OCRResMgr::set_cpu()
Expand All @@ -34,12 +32,6 @@ bool OCRResMgr::set_gpu(int device_id)
return true;
}

void OCRResMgr::set_auto_device()
{
// TODO: 检查 GPU 列表,并过滤一些老旧设备
set_gpu(0);
}

bool OCRResMgr::lazy_load(const std::filesystem::path& path, bool is_base)
{
LogFunc << VAR(path) << VAR(is_base);
Expand Down
1 change: 0 additions & 1 deletion source/MaaFramework/Resource/OCRResMgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class OCRResMgr : public NonCopyable

void set_cpu();
bool set_gpu(int device_id);
void set_auto_device();

bool lazy_load(const std::filesystem::path& path, bool is_base);
void clear();
Expand Down
13 changes: 0 additions & 13 deletions source/MaaFramework/Resource/ONNXResMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,6 @@ MAA_RES_NS_BEGIN
// }
// }

ONNXResMgr::ONNXResMgr()
{
LogFunc;

set_auto_device();
}

void ONNXResMgr::set_cpu()
{
LogInfo;
Expand Down Expand Up @@ -114,12 +107,6 @@ bool ONNXResMgr::set_gpu(int device_id)
return true;
}

void ONNXResMgr::set_auto_device()
{
// TODO: 检查 GPU 列表,并过滤一些老旧设备
set_gpu(0);
}

bool ONNXResMgr::lazy_load(const std::filesystem::path& path, bool is_base)
{
LogFunc << VAR(path) << VAR(is_base);
Expand Down
3 changes: 0 additions & 3 deletions source/MaaFramework/Resource/ONNXResMgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,9 @@ class ONNXResMgr : public NonCopyable
inline static const std::filesystem::path kClassifierDir = "classify";
inline static const std::filesystem::path kDetectorDir = "detect";

ONNXResMgr();

public:
void set_cpu();
bool set_gpu(int device_id);
void set_auto_device();

bool lazy_load(const std::filesystem::path& path, bool is_base);
void clear();
Expand Down
40 changes: 28 additions & 12 deletions source/MaaFramework/Resource/ResourceMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <tuple>

#include "MaaFramework/MaaMsg.h"
#include "Utils/GpuOption.h"
#include "Utils/Logger.h"
#include "Utils/Platform.h"

Expand All @@ -15,6 +16,8 @@ ResourceMgr::ResourceMgr(MaaNotificationCallback notify, void* notify_trans_arg)

res_loader_ = std::make_unique<AsyncRunner<std::filesystem::path>>(
std::bind(&ResourceMgr::run_load, this, std::placeholders::_1, std::placeholders::_2));

check_and_set_gpu();
}

ResourceMgr::~ResourceMgr()
Expand Down Expand Up @@ -257,29 +260,42 @@ bool ResourceMgr::set_inference_device(MaaOptionValue value, MaaOptionValueSize
return false;
}

int32_t device_id = *reinterpret_cast<int*>(value);
LogInfo << VAR(device_id);
int32_t device = *reinterpret_cast<int*>(value);
LogInfo << VAR(device);

if (device_id == MaaInferenceDevice_CPU) {
if (device == MaaInferenceDevice_Auto) {
check_and_set_gpu();
}
else if (device == MaaInferenceDevice_CPU) {
onnx_res_.set_cpu();
ocr_res_.set_cpu();
}
else if (device_id == MaaInferenceDevice_Auto) {
onnx_res_.set_auto_device();
ocr_res_.set_auto_device();
}
else if (device_id >= 0) {
onnx_res_.set_gpu(device_id);
ocr_res_.set_gpu(device_id);
else if (device >= 0) {
onnx_res_.set_gpu(device);
ocr_res_.set_gpu(device);
}
else { // device_id < -2
LogError << "invalid inference device" << VAR(device_id);
else {
LogError << "invalid inference device" << VAR(device);
return false;
}

return true;
}

void ResourceMgr::check_and_set_gpu()
{
auto gpu = perfer_gpu();
if (gpu) {
int32_t gpu_id = *gpu;
onnx_res_.set_gpu(gpu_id);
ocr_res_.set_gpu(gpu_id);
}
else {
onnx_res_.set_cpu();
ocr_res_.set_cpu();
}
}

bool ResourceMgr::run_load(typename AsyncRunner<std::filesystem::path>::Id id, std::filesystem::path path)
{
LogFunc << VAR(id) << VAR(path);
Expand Down
2 changes: 2 additions & 0 deletions source/MaaFramework/Resource/ResourceMgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class ResourceMgr : public MaaResource
private:
bool set_inference_device(MaaOptionValue value, MaaOptionValueSize val_size);

void check_and_set_gpu();

bool run_load(typename AsyncRunner<std::filesystem::path>::Id id, std::filesystem::path path);
bool load(const std::filesystem::path& path);
bool check_stop();
Expand Down
5 changes: 5 additions & 0 deletions source/MaaUtils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ add_library(MaaUtils SHARED ${maa_utils_src} ${maa_utils_header})
target_include_directories(MaaUtils
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ${MAA_PRIVATE_INC} ${MAA_PUBLIC_INC})
target_link_libraries(MaaUtils PRIVATE HeaderOnlyLibraries Boost::system ${OpenCV_LIBS})

if(WIN32)
target_link_libraries(MaaUtils PRIVATE d3d12 dxgi Cfgmgr32)
endif()

target_compile_definitions(MaaUtils PRIVATE MAA_UTILS_EXPORTS)

install(
Expand Down
16 changes: 16 additions & 0 deletions source/MaaUtils/GpuOption/GpuOptionMacOS.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifdef __APPLE__

#include "Utils/GpuOption.h"

MAA_NS_BEGIN

std::optional<int> perfer_gpu()
{
// TODO
return std::nullopt;
}

MAA_NS_END

#endif // __APPLE__

107 changes: 107 additions & 0 deletions source/MaaUtils/GpuOption/GpuOptionWin32.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
#ifdef _WIN32

#include "Utils/GpuOption.h"
#include "Utils/SafeWindows.hpp"

#include <initguid.h>

#include <cfgmgr32.h>
#include <d3d12.h>
#include <devpkey.h>
#include <devpropdef.h>
#include <dxgi1_6.h>

#include "Utils/Logger.h"

MAA_NS_BEGIN

std::optional<std::wstring> adapter_instance_path(LUID luid)
{
DISPLAYCONFIG_ADAPTER_NAME req {};
req.header.size = sizeof(DISPLAYCONFIG_ADAPTER_NAME);
req.header.adapterId = luid;
req.header.id = 0;
req.header.type = DISPLAYCONFIG_DEVICE_INFO_GET_ADAPTER_NAME;

LONG adpname = DisplayConfigGetDeviceInfo(&req.header);
std::ignore = adpname;

ULONG size = 0;
DEVPROPTYPE type {};
CONFIGRET err = CM_Get_Device_Interface_PropertyW(req.adapterDevicePath, &DEVPKEY_Device_InstanceId, &type, nullptr, &size, 0);
if (err != CR_BUFFER_SMALL) {
return std::nullopt;
}
if (type != DEVPROP_TYPE_STRING) {
return std::nullopt;
}

std::vector<BYTE> buf(size);
err = CM_Get_Device_Interface_PropertyW(req.adapterDevicePath, &DEVPKEY_Device_InstanceId, &type, buf.data(), &size, 0);
if (err != CR_SUCCESS) {
return std::nullopt;
}

std::wstring result(reinterpret_cast<const wchar_t*>(buf.data()), size / 2 - 1);
LogTrace << VAR(result);
return result;
}

std::optional<int> perfer_gpu()
{
IDXGIFactory4* dxgi_factory = nullptr;
OnScopeLeave([&]() {
if (dxgi_factory) {
dxgi_factory->Release();
dxgi_factory = nullptr;
}
});

HRESULT ret = CreateDXGIFactory2(0, __uuidof(IDXGIFactory4), reinterpret_cast<void**>(&dxgi_factory));
if (FAILED(ret)) {
LogError << "CreateDXGIFactory2 failed" << VAR(ret);
return false;
}

for (UINT adapter_index = 0;; ++adapter_index) {
IDXGIAdapter1* dxgi_adapter = nullptr;
OnScopeLeave([&]() {
if (dxgi_adapter) {
dxgi_adapter->Release();
dxgi_adapter = nullptr;
}
});

HRESULT hr = dxgi_factory->EnumAdapters1(adapter_index, &dxgi_adapter);
if (hr == DXGI_ERROR_NOT_FOUND) {
break;
}
else if (FAILED(hr)) {
LogError << "EnumAdapters1 failed" << VAR(hr);
continue;
}

DXGI_ADAPTER_DESC1 desc {};
hr = dxgi_adapter->GetDesc1(&desc);
if (FAILED(hr)) {
LogError << "GetDesc1 failed" << VAR(hr);
continue;
}

std::wstring gpu_desc(desc.Description);
LogTrace << VAR(adapter_index) << VAR(gpu_desc);

if (gpu_desc.find(L"NVIDIA") == std::wstring::npos) {
continue;
}
return adapter_index;

// auto instance_path = adapter_instance_path(desc.AdapterLuid);
}

return std::nullopt;
}

MAA_NS_END

#endif // _WIN32
15 changes: 15 additions & 0 deletions source/MaaUtils/GpuOption/GpuOption_NotImpl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#if !defined(__APPLE__) && !defined(_WIN32)

#include "Utils/GpuOption.h"

MAA_NS_BEGIN

std::optional<int> perfer_gpu()
{
// TODO
return std::nullopt;
}

MAA_NS_END

#endif // !defined(__APPLE__) && !defined(_WIN32)
4 changes: 2 additions & 2 deletions source/MaaWin32ControlUnit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ add_library(MaaWin32ControlUnit SHARED ${maa_win32_control_unit_src} ${maa_win32
target_include_directories(MaaWin32ControlUnit
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ${MAA_PRIVATE_INC} ${MAA_PUBLIC_INC})

target_link_libraries(MaaWin32ControlUnit MaaUtils HeaderOnlyLibraries ${OpenCV_LIBS} ZLIB::ZLIB Boost::system)
target_link_libraries(MaaWin32ControlUnit d3d11 dxgi)
target_link_libraries(MaaWin32ControlUnit PRIVATE MaaUtils HeaderOnlyLibraries ${OpenCV_LIBS} ZLIB::ZLIB Boost::system)
target_link_libraries(MaaWin32ControlUnit PRIVATE d3d11 dxgi)

target_compile_definitions(MaaWin32ControlUnit PRIVATE MAA_CONTROL_UNIT_EXPORTS)

Expand Down
12 changes: 12 additions & 0 deletions source/include/Utils/GpuOption.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

#include <optional>

#include "Conf/Conf.h"
#include "MaaFramework/MaaPort.h"

MAA_NS_BEGIN

MAA_UTILS_API std::optional<int> perfer_gpu();

MAA_NS_END

0 comments on commit 53654ff

Please sign in to comment.