Skip to content

Commit ad076d3

Browse files
Add zelReloadDrivers(flags) API
Provides a means to re-initialize all of the drivers' library handles and DDI tables. The value of flags must match what was provided to zeInit(flags). Signed-off-by: Lisanna Dettwyler <[email protected]>
1 parent 519eed2 commit ad076d3

File tree

7 files changed

+313
-2
lines changed

7 files changed

+313
-2
lines changed

doc/loader_api.md

+6
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ There are currently 3 versioned components assigned the following name strings:
2121
- `"validation layer"`
2222
- `"loader"`
2323

24+
### zelReloadDrivers
25+
26+
Close, reload, and re-initialize through zeInit all driver libraries currently loaded.
27+
28+
- __flags__ init flags that will be passed to each driver's implementation of zeInit, it should match what was previously provided at the first zeInit.
29+
2430

2531
### zelLoaderTranslateHandle
2632

include/loader/ze_loader.h

+4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ zelLoaderGetVersions(
3939
size_t *num_elems, //Pointer to num versions to get.
4040
zel_component_version_t *versions); //Pointer to array of versions. If set to NULL, num_elems is returned
4141

42+
ZE_APIEXPORT ze_result_t ZE_APICALL
43+
zelReloadDrivers(
44+
ze_init_flags_t flags); //Init flags, should match flags used in zeInit
45+
4246
typedef enum _zel_handle_type_t {
4347
ZEL_HANDLE_DRIVER,
4448
ZEL_HANDLE_DEVICE,

source/lib/ze_lib.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,22 @@ zelLoaderGetVersions(
171171
#endif
172172
}
173173

174+
ze_result_t ZE_APICALL
175+
zelReloadDrivers(
176+
ze_init_flags_t flags)
177+
{
178+
#ifdef DYNAMIC_LOAD_LOADER
179+
if(nullptr == ze_lib::context->loader)
180+
return ZE_RESULT_ERROR;
181+
typedef ze_result_t (ZE_APICALL *zelReloadDriver_t)(ze_driver_handle_t hDriver);
182+
auto reloadDrivers = reinterpret_cast<zelReloadDriver_t>(
183+
GET_FUNCTION_PTR(ze_lib::context->loader, "zelReloadDriversInternal") );
184+
return reloadDrivers(flags);
185+
#else
186+
return zelReloadDriversInternal(flags);
187+
#endif
188+
}
189+
174190

175191
ze_result_t ZE_APICALL
176192
zelLoaderTranslateHandle(

source/loader/ze_loader_api.cpp

+252
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,258 @@ zelLoaderGetVersionsInternal(
7373
return ZE_RESULT_SUCCESS;
7474
}
7575

76+
ZE_DLLEXPORT ze_result_t ZE_APICALL
77+
zelReloadDriversInternal(
78+
ze_init_flags_t flags)
79+
{
80+
for( auto& drv : loader::context->zeDrivers ) {
81+
if(drv.initStatus != ZE_RESULT_SUCCESS)
82+
continue;
83+
84+
if (drv.handle) {
85+
auto free_result = FREE_DRIVER_LIBRARY( drv.handle );
86+
auto failure = FREE_DRIVER_LIBRARY_FAILURE_CHECK(free_result);
87+
if (failure)
88+
return ZE_RESULT_ERROR_UNINITIALIZED;
89+
}
90+
91+
drv.handle = LOAD_DRIVER_LIBRARY( drv.name.c_str() );
92+
if (NULL == drv.handle)
93+
return ZE_RESULT_ERROR_UNINITIALIZED;
94+
95+
auto zeGetGlobalProcAddrTable = reinterpret_cast<ze_pfnGetGlobalProcAddrTable_t>(
96+
GET_FUNCTION_PTR( drv.handle, "zeGetGlobalProcAddrTable") );
97+
if (!zeGetGlobalProcAddrTable)
98+
return ZE_RESULT_ERROR_UNINITIALIZED;
99+
auto zeGetGlobalProcAddrTableResult = zeGetGlobalProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Global);
100+
if (zeGetGlobalProcAddrTableResult != ZE_RESULT_SUCCESS)
101+
return zeGetGlobalProcAddrTableResult;
102+
103+
auto zeGetRTASBuilderExpProcAddrTable = reinterpret_cast<ze_pfnGetRTASBuilderExpProcAddrTable_t>(
104+
GET_FUNCTION_PTR( drv.handle, "zeGetRTASBuilderExpProcAddrTable") );
105+
if (!zeGetRTASBuilderExpProcAddrTable)
106+
return ZE_RESULT_ERROR_UNINITIALIZED;
107+
auto zeGetRTASBuilderExpProcAddrTableResult = zeGetRTASBuilderExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.RTASBuilderExp);
108+
if (zeGetRTASBuilderExpProcAddrTableResult != ZE_RESULT_SUCCESS)
109+
return zeGetRTASBuilderExpProcAddrTableResult;
110+
111+
auto zeGetRTASParallelOperationExpProcAddrTable = reinterpret_cast<ze_pfnGetRTASParallelOperationExpProcAddrTable_t>(
112+
GET_FUNCTION_PTR( drv.handle, "zeGetRTASParallelOperationExpProcAddrTable") );
113+
if (!zeGetRTASParallelOperationExpProcAddrTable)
114+
return ZE_RESULT_ERROR_UNINITIALIZED;
115+
auto zeGetRTASParallelOperationExpProcAddrTableResult = zeGetRTASParallelOperationExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.RTASParallelOperationExp);
116+
if (zeGetRTASParallelOperationExpProcAddrTableResult != ZE_RESULT_SUCCESS)
117+
return zeGetRTASParallelOperationExpProcAddrTableResult;
118+
119+
auto zeGetDriverProcAddrTable = reinterpret_cast<ze_pfnGetDriverProcAddrTable_t>(
120+
GET_FUNCTION_PTR( drv.handle, "zeGetDriverProcAddrTable") );
121+
if (!zeGetDriverProcAddrTable)
122+
return ZE_RESULT_ERROR_UNINITIALIZED;
123+
auto zeGetDriverProcAddrTableResult = zeGetDriverProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Driver);
124+
if (zeGetDriverProcAddrTableResult != ZE_RESULT_SUCCESS)
125+
return zeGetDriverProcAddrTableResult;
126+
127+
auto zeGetDriverExpProcAddrTable = reinterpret_cast<ze_pfnGetDriverExpProcAddrTable_t>(
128+
GET_FUNCTION_PTR( drv.handle, "zeGetDriverExpProcAddrTable") );
129+
if (!zeGetDriverExpProcAddrTable)
130+
return ZE_RESULT_ERROR_UNINITIALIZED;
131+
auto zeGetDriverExpProcAddrTableResult = zeGetDriverExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.DriverExp);
132+
if (zeGetDriverExpProcAddrTableResult != ZE_RESULT_SUCCESS)
133+
return zeGetDriverExpProcAddrTableResult;
134+
135+
auto zeGetDeviceProcAddrTable = reinterpret_cast<ze_pfnGetDeviceProcAddrTable_t>(
136+
GET_FUNCTION_PTR( drv.handle, "zeGetDeviceProcAddrTable") );
137+
if (!zeGetDeviceProcAddrTable)
138+
return ZE_RESULT_ERROR_UNINITIALIZED;
139+
auto zeGetDeviceProcAddrTableResult = zeGetDeviceProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Device);
140+
if (zeGetDeviceProcAddrTableResult != ZE_RESULT_SUCCESS)
141+
return zeGetDeviceProcAddrTableResult;
142+
143+
auto zeGetDeviceExpProcAddrTable = reinterpret_cast<ze_pfnGetDeviceExpProcAddrTable_t>(
144+
GET_FUNCTION_PTR( drv.handle, "zeGetDeviceExpProcAddrTable") );
145+
if (!zeGetDeviceExpProcAddrTable)
146+
return ZE_RESULT_ERROR_UNINITIALIZED;
147+
auto zeGetDeviceExpProcAddrTableResult = zeGetDeviceExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.DeviceExp);
148+
if (zeGetDeviceExpProcAddrTableResult != ZE_RESULT_SUCCESS)
149+
return zeGetDeviceExpProcAddrTableResult;
150+
151+
auto zeGetContextProcAddrTable = reinterpret_cast<ze_pfnGetContextProcAddrTable_t>(
152+
GET_FUNCTION_PTR( drv.handle, "zeGetContextProcAddrTable") );
153+
if (!zeGetContextProcAddrTable)
154+
return ZE_RESULT_ERROR_UNINITIALIZED;
155+
auto zeGetContextProcAddrTableResult = zeGetContextProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Context);
156+
if (zeGetContextProcAddrTableResult != ZE_RESULT_SUCCESS)
157+
return zeGetContextProcAddrTableResult;
158+
159+
auto zeGetCommandQueueProcAddrTable = reinterpret_cast<ze_pfnGetCommandQueueProcAddrTable_t>(
160+
GET_FUNCTION_PTR( drv.handle, "zeGetCommandQueueProcAddrTable") );
161+
if (!zeGetCommandQueueProcAddrTable)
162+
return ZE_RESULT_ERROR_UNINITIALIZED;
163+
auto zeGetCommandQueueProcAddrTableResult = zeGetCommandQueueProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.CommandQueue);
164+
if (zeGetCommandQueueProcAddrTableResult != ZE_RESULT_SUCCESS)
165+
return zeGetCommandQueueProcAddrTableResult;
166+
167+
auto zeGetCommandListProcAddrTable = reinterpret_cast<ze_pfnGetCommandListProcAddrTable_t>(
168+
GET_FUNCTION_PTR( drv.handle, "zeGetCommandListProcAddrTable") );
169+
if (!zeGetCommandListProcAddrTable)
170+
return ZE_RESULT_ERROR_UNINITIALIZED;
171+
auto zeGetCommandListProcAddrTableResult = zeGetCommandListProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.CommandList);
172+
if (zeGetCommandListProcAddrTableResult != ZE_RESULT_SUCCESS)
173+
return zeGetCommandListProcAddrTableResult;
174+
175+
auto zeGetCommandListExpProcAddrTable = reinterpret_cast<ze_pfnGetCommandListExpProcAddrTable_t>(
176+
GET_FUNCTION_PTR( drv.handle, "zeGetCommandListExpProcAddrTable") );
177+
if (!zeGetCommandListExpProcAddrTable)
178+
return ZE_RESULT_ERROR_UNINITIALIZED;
179+
auto zeGetCommandListExpProcAddrTableResult = zeGetCommandListExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.CommandListExp);
180+
if (zeGetCommandListExpProcAddrTableResult != ZE_RESULT_SUCCESS)
181+
return zeGetCommandListExpProcAddrTableResult;
182+
183+
auto zeGetEventProcAddrTable = reinterpret_cast<ze_pfnGetEventProcAddrTable_t>(
184+
GET_FUNCTION_PTR( drv.handle, "zeGetEventProcAddrTable") );
185+
if (!zeGetEventProcAddrTable)
186+
return ZE_RESULT_ERROR_UNINITIALIZED;
187+
auto zeGetEventProcAddrTableResult = zeGetEventProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Event);
188+
if (zeGetEventProcAddrTableResult != ZE_RESULT_SUCCESS)
189+
return zeGetEventProcAddrTableResult;
190+
191+
auto zeGetEventExpProcAddrTable = reinterpret_cast<ze_pfnGetEventExpProcAddrTable_t>(
192+
GET_FUNCTION_PTR( drv.handle, "zeGetEventExpProcAddrTable") );
193+
if (!zeGetEventExpProcAddrTable)
194+
return ZE_RESULT_ERROR_UNINITIALIZED;
195+
auto zeGetEventExpProcAddrTableResult = zeGetEventExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.EventExp);
196+
if (zeGetEventExpProcAddrTableResult != ZE_RESULT_SUCCESS)
197+
return zeGetEventExpProcAddrTableResult;
198+
199+
auto zeGetEventPoolProcAddrTable = reinterpret_cast<ze_pfnGetEventPoolProcAddrTable_t>(
200+
GET_FUNCTION_PTR( drv.handle, "zeGetEventPoolProcAddrTable") );
201+
if (!zeGetEventPoolProcAddrTable)
202+
return ZE_RESULT_ERROR_UNINITIALIZED;
203+
auto zeGetEventPoolProcAddrTableResult = zeGetEventPoolProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.EventPool);
204+
if (zeGetEventPoolProcAddrTableResult != ZE_RESULT_SUCCESS)
205+
return zeGetEventPoolProcAddrTableResult;
206+
207+
auto zeGetFenceProcAddrTable = reinterpret_cast<ze_pfnGetFenceProcAddrTable_t>(
208+
GET_FUNCTION_PTR( drv.handle, "zeGetFenceProcAddrTable") );
209+
if (!zeGetFenceProcAddrTable)
210+
return ZE_RESULT_ERROR_UNINITIALIZED;
211+
auto zeGetFenceProcAddrTableResult = zeGetFenceProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Fence);
212+
if (zeGetFenceProcAddrTableResult != ZE_RESULT_SUCCESS)
213+
return zeGetFenceProcAddrTableResult;
214+
215+
auto zeGetImageProcAddrTable = reinterpret_cast<ze_pfnGetImageProcAddrTable_t>(
216+
GET_FUNCTION_PTR( drv.handle, "zeGetImageProcAddrTable") );
217+
if (!zeGetImageProcAddrTable)
218+
return ZE_RESULT_ERROR_UNINITIALIZED;
219+
auto zeGetImageProcAddrTableResult = zeGetImageProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Image);
220+
if (zeGetImageProcAddrTableResult != ZE_RESULT_SUCCESS)
221+
return zeGetImageProcAddrTableResult;
222+
223+
auto zeGetImageExpProcAddrTable = reinterpret_cast<ze_pfnGetImageExpProcAddrTable_t>(
224+
GET_FUNCTION_PTR( drv.handle, "zeGetImageExpProcAddrTable") );
225+
if (!zeGetImageExpProcAddrTable)
226+
return ZE_RESULT_ERROR_UNINITIALIZED;
227+
auto zeGetImageExpProcAddrTableResult = zeGetImageExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.ImageExp);
228+
if (zeGetImageExpProcAddrTableResult != ZE_RESULT_SUCCESS)
229+
return zeGetImageExpProcAddrTableResult;
230+
231+
auto zeGetKernelProcAddrTable = reinterpret_cast<ze_pfnGetKernelProcAddrTable_t>(
232+
GET_FUNCTION_PTR( drv.handle, "zeGetKernelProcAddrTable") );
233+
if (!zeGetKernelProcAddrTable)
234+
return ZE_RESULT_ERROR_UNINITIALIZED;
235+
auto zeGetKernelProcAddrTableResult = zeGetKernelProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Kernel);
236+
if (zeGetKernelProcAddrTableResult != ZE_RESULT_SUCCESS)
237+
return zeGetKernelProcAddrTableResult;
238+
239+
auto zeGetKernelExpProcAddrTable = reinterpret_cast<ze_pfnGetKernelExpProcAddrTable_t>(
240+
GET_FUNCTION_PTR( drv.handle, "zeGetKernelExpProcAddrTable") );
241+
if (!zeGetKernelExpProcAddrTable)
242+
return ZE_RESULT_ERROR_UNINITIALIZED;
243+
auto zeGetKernelExpProcAddrTableResult = zeGetKernelExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.KernelExp);
244+
if (zeGetKernelExpProcAddrTableResult != ZE_RESULT_SUCCESS)
245+
return zeGetKernelExpProcAddrTableResult;
246+
247+
auto zeGetMemProcAddrTable = reinterpret_cast<ze_pfnGetMemProcAddrTable_t>(
248+
GET_FUNCTION_PTR( drv.handle, "zeGetMemProcAddrTable") );
249+
if (!zeGetMemProcAddrTable)
250+
return ZE_RESULT_ERROR_UNINITIALIZED;
251+
auto zeGetMemProcAddrTableResult = zeGetMemProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Mem);
252+
if (zeGetMemProcAddrTableResult != ZE_RESULT_SUCCESS)
253+
return zeGetMemProcAddrTableResult;
254+
255+
auto zeGetMemExpProcAddrTable = reinterpret_cast<ze_pfnGetMemExpProcAddrTable_t>(
256+
GET_FUNCTION_PTR( drv.handle, "zeGetMemExpProcAddrTable") );
257+
if (!zeGetMemExpProcAddrTable)
258+
return ZE_RESULT_ERROR_UNINITIALIZED;
259+
auto zeGetMemExpProcAddrTableResult = zeGetMemExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.MemExp);
260+
if (zeGetMemExpProcAddrTableResult != ZE_RESULT_SUCCESS)
261+
return zeGetMemExpProcAddrTableResult;
262+
263+
auto zeGetModuleProcAddrTable = reinterpret_cast<ze_pfnGetModuleProcAddrTable_t>(
264+
GET_FUNCTION_PTR( drv.handle, "zeGetModuleProcAddrTable") );
265+
if (!zeGetModuleProcAddrTable)
266+
return ZE_RESULT_ERROR_UNINITIALIZED;
267+
auto zeGetModuleProcAddrTableResult = zeGetModuleProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Module);
268+
if (zeGetModuleProcAddrTableResult != ZE_RESULT_SUCCESS)
269+
return zeGetModuleProcAddrTableResult;
270+
271+
auto zeGetModuleBuildLogProcAddrTable = reinterpret_cast<ze_pfnGetModuleBuildLogProcAddrTable_t>(
272+
GET_FUNCTION_PTR( drv.handle, "zeGetModuleBuildLogProcAddrTable") );
273+
if (!zeGetModuleBuildLogProcAddrTable)
274+
return ZE_RESULT_ERROR_UNINITIALIZED;
275+
auto zeGetModuleBuildLogProcAddrTableResult = zeGetModuleBuildLogProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.ModuleBuildLog);
276+
if (zeGetModuleBuildLogProcAddrTableResult != ZE_RESULT_SUCCESS)
277+
return zeGetModuleBuildLogProcAddrTableResult;
278+
279+
auto zeGetPhysicalMemProcAddrTable = reinterpret_cast<ze_pfnGetPhysicalMemProcAddrTable_t>(
280+
GET_FUNCTION_PTR( drv.handle, "zeGetPhysicalMemProcAddrTable") );
281+
if (!zeGetPhysicalMemProcAddrTable)
282+
return ZE_RESULT_ERROR_UNINITIALIZED;
283+
auto zeGetPhysicalMemProcAddrTableResult = zeGetPhysicalMemProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.PhysicalMem);
284+
if (zeGetPhysicalMemProcAddrTableResult != ZE_RESULT_SUCCESS)
285+
return zeGetPhysicalMemProcAddrTableResult;
286+
287+
auto zeGetSamplerProcAddrTable = reinterpret_cast<ze_pfnGetSamplerProcAddrTable_t>(
288+
GET_FUNCTION_PTR( drv.handle, "zeGetSamplerProcAddrTable") );
289+
if (!zeGetSamplerProcAddrTable)
290+
return ZE_RESULT_ERROR_UNINITIALIZED;
291+
auto zeGetSamplerProcAddrTableResult = zeGetSamplerProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Sampler);
292+
if (zeGetSamplerProcAddrTableResult != ZE_RESULT_SUCCESS)
293+
return zeGetSamplerProcAddrTableResult;
294+
295+
auto zeGetVirtualMemProcAddrTable = reinterpret_cast<ze_pfnGetVirtualMemProcAddrTable_t>(
296+
GET_FUNCTION_PTR( drv.handle, "zeGetVirtualMemProcAddrTable") );
297+
if (!zeGetVirtualMemProcAddrTable)
298+
return ZE_RESULT_ERROR_UNINITIALIZED;
299+
auto zeGetVirtualMemProcAddrTableResult = zeGetVirtualMemProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.VirtualMem);
300+
if (zeGetVirtualMemProcAddrTableResult != ZE_RESULT_SUCCESS)
301+
return zeGetVirtualMemProcAddrTableResult;
302+
303+
auto zeGetFabricEdgeExpProcAddrTable = reinterpret_cast<ze_pfnGetFabricEdgeExpProcAddrTable_t>(
304+
GET_FUNCTION_PTR( drv.handle, "zeGetFabricEdgeExpProcAddrTable") );
305+
if (!zeGetFabricEdgeExpProcAddrTable)
306+
return ZE_RESULT_ERROR_UNINITIALIZED;
307+
auto zeGetFabricEdgeExpProcAddrTableResult = zeGetFabricEdgeExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.FabricEdgeExp);
308+
if (zeGetFabricEdgeExpProcAddrTableResult != ZE_RESULT_SUCCESS)
309+
return zeGetFabricEdgeExpProcAddrTableResult;
310+
311+
auto zeGetFabricVertexExpProcAddrTable = reinterpret_cast<ze_pfnGetFabricVertexExpProcAddrTable_t>(
312+
GET_FUNCTION_PTR( drv.handle, "zeGetFabricVertexExpProcAddrTable") );
313+
if (!zeGetFabricVertexExpProcAddrTable)
314+
return ZE_RESULT_ERROR_UNINITIALIZED;
315+
auto zeGetFabricVertexExpProcAddrTableResult = zeGetFabricVertexExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.FabricVertexExp);
316+
if (zeGetFabricVertexExpProcAddrTableResult != ZE_RESULT_SUCCESS)
317+
return zeGetFabricVertexExpProcAddrTableResult;
318+
319+
auto initResult = drv.dditable.ze.Global.pfnInit(flags);
320+
// Bail out if any drivers that previously succeeded fail
321+
if (initResult != ZE_RESULT_SUCCESS)
322+
return initResult;
323+
}
324+
325+
return ZE_RESULT_SUCCESS;
326+
}
327+
76328

77329
ZE_DLLEXPORT ze_result_t ZE_APICALL
78330
zelLoaderTranslateHandleInternal(

source/loader/ze_loader_api.h

+5
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ zelLoaderGetVersionsInternal(
6868
zel_component_version_t *versions); //Pointer to array of versions. If set to NULL, num_elems is returned
6969

7070

71+
ZE_DLLEXPORT ze_result_t ZE_APICALL
72+
zelReloadDriversInternal(
73+
ze_init_flags_t flags);
74+
75+
7176
ZE_DLLEXPORT ze_result_t ZE_APICALL
7277
zelLoaderTranslateHandleInternal(
7378
zel_handle_type_t handleType, //Handle type

test/CMakeLists.txt

+4-2
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,7 @@ if(MSVC)
1818
target_compile_options(tests PRIVATE "/MD$<$<CONFIG:Debug>:d>")
1919
endif()
2020

21-
add_test(NAME tests COMMAND tests)
22-
set_property(TEST tests PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1")
21+
add_test(NAME tests_api_version COMMAND tests --gtest_filter=LoaderAPI.GivenLevelZeroLoaderPresentWhenCallingzeGetLoaderVersionsAPIThenValidVersionIsReturned)
22+
set_property(TEST tests_api_version PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1")
23+
add_test(NAME tests_api_reload COMMAND tests --gtest_filter=LoaderAPI.GivenInitWhenCallingzelReloadDriversThenDriversStillWork)
24+
set_property(TEST tests_api_reload PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1")

test/loader_api.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,30 @@ TEST(
4242
}
4343
}
4444

45+
TEST(
46+
LoaderAPI,
47+
GivenInitWhenCallingzelReloadDriversThenDriversStillWork
48+
) {
49+
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(0));
50+
51+
uint32_t count = 0;
52+
EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGet(&count, nullptr));
53+
EXPECT_GT(count, 0);
54+
55+
std::vector<ze_driver_handle_t> hDrivers(count);
56+
EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGet(&count, hDrivers.data()));
57+
58+
for (auto &driver : hDrivers) {
59+
ze_driver_properties_t driverProperties;
60+
EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGetProperties(driver, &driverProperties));
61+
}
62+
63+
EXPECT_EQ(ZE_RESULT_SUCCESS, zelReloadDrivers(0));
64+
65+
for (auto &driver : hDrivers) {
66+
ze_driver_properties_t driverProperties;
67+
EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGetProperties(driver, &driverProperties));
68+
}
69+
}
70+
4571
} // namespace

0 commit comments

Comments
 (0)