Skip to content

Commit 31a92e9

Browse files
Add test to load many kernels concurrently and check for errors
1 parent 2bdc03d commit 31a92e9

File tree

1 file changed

+343
-0
lines changed

1 file changed

+343
-0
lines changed
Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
#
2+
# ISC License
3+
#
4+
# Copyright (c) 2025, Autonomous Vehicle
5+
# Systems Lab, University of Colorado at Boulder
6+
#
7+
# Permission to use, copy, modify, and/or distribute this software for any
8+
# purpose with or without fee is hereby granted, provided that the above
9+
# copyright notice and this permission notice appear in all copies.
10+
#
11+
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
12+
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
13+
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
14+
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
15+
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
16+
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
17+
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
18+
#
19+
20+
import os
21+
import sys
22+
import time
23+
import multiprocessing as mp
24+
import pytest
25+
import traceback
26+
27+
from Basilisk import __path__
28+
from Basilisk.simulation import spiceInterface
29+
30+
r"""
31+
Unit Test for SPICE Interface Thread Safety
32+
===========================================
33+
34+
This script stress-tests the SPICE interface in parallel, reproducing
35+
the conditions of GitHub issue #220 where parallel simulations using
36+
SPICE could deadlock or corrupt data.
37+
38+
Multiple worker processes repeatedly create and destroy SpiceInterface
39+
instances, forcing concurrent kernel load/unload operations. The test
40+
passes if all workers complete without hangs or unhandled exceptions.
41+
"""
42+
43+
bskPath = __path__[0]
44+
45+
def createLoadDestroySpice(workerId, iterations, dataPath):
46+
"""
47+
Repeatedly create, reset, and destroy SpiceInterface objects.
48+
49+
This function is run in parallel by multiple processes. Each worker
50+
performs `iterations` cycles of:
51+
1. Constructing a SpiceInterface
52+
2. Configuring planet names and SPICE data path
53+
3. Calling Reset (which triggers kernel loads)
54+
4. Brief sleep to increase contention
55+
5. Deleting the interface (allowing kernels to be released)
56+
57+
Parameters
58+
----------
59+
workerId : int
60+
Identifier for this worker process.
61+
iterations : int
62+
Number of create/reset/destroy cycles to perform.
63+
dataPath : str
64+
Directory containing SPICE kernel data.
65+
66+
Returns
67+
-------
68+
dict
69+
Summary for this worker with counts of successes, failures, and
70+
a list of captured exception details.
71+
"""
72+
print(f"Worker {workerId} starting with {iterations} iterations")
73+
74+
successCount = 0
75+
failureCount = 0
76+
exceptionList = []
77+
78+
try:
79+
for iteration in range(iterations):
80+
try:
81+
# Create a new SpiceInterface
82+
spiceObj = spiceInterface.SpiceInterface()
83+
84+
# Use a fixed planet set to avoid random differences
85+
planets = ["earth", "sun"]
86+
spiceObj.addPlanetNames(planets)
87+
88+
# Configure SPICE data path and trigger kernel loads
89+
spiceObj.SPICEDataPath = dataPath
90+
spiceObj.Reset(0)
91+
92+
# Short sleep to encourage overlap among workers
93+
time.sleep(0.001)
94+
95+
# Drop reference so the object can be destroyed
96+
del spiceObj
97+
98+
successCount += 1
99+
print(
100+
f"Worker {workerId} completed iteration "
101+
f"{iteration + 1}/{iterations}"
102+
)
103+
except Exception as exc:
104+
failureCount += 1
105+
errorInfo = {
106+
"workerId": workerId,
107+
"iteration": iteration,
108+
"error": str(exc),
109+
"traceback": traceback.format_exc(),
110+
}
111+
exceptionList.append(errorInfo)
112+
print(
113+
f"Worker {workerId} failed at iteration {iteration} "
114+
f"with error: {exc}"
115+
)
116+
# Continue with next iteration
117+
continue
118+
119+
except Exception as exc:
120+
# Catch any exception outside the main loop
121+
failureCount += 1
122+
errorInfo = {
123+
"workerId": workerId,
124+
"iteration": -1, # Outside the loop
125+
"error": str(exc),
126+
"traceback": traceback.format_exc(),
127+
}
128+
exceptionList.append(errorInfo)
129+
print(
130+
f"Worker {workerId} failed with error outside iteration loop: {exc}"
131+
)
132+
133+
return {
134+
"workerId": workerId,
135+
"successCount": successCount,
136+
"failureCount": failureCount,
137+
"exceptions": exceptionList,
138+
}
139+
140+
141+
def runThreadSafetyTest(numWorkers=2, iterationsPerWorker=5):
142+
"""
143+
Run the SPICE thread-safety stress test.
144+
145+
Parameters
146+
----------
147+
numWorkers : int
148+
Number of parallel worker processes to launch.
149+
iterationsPerWorker : int
150+
Number of create/reset/destroy cycles per worker.
151+
152+
Returns
153+
-------
154+
results : dict
155+
Aggregate statistics over all workers.
156+
success : bool
157+
True if all iterations completed without failure, False otherwise.
158+
"""
159+
print(f"Starting SPICE Thread Safety Test with {numWorkers} workers")
160+
print(f"Each worker will perform {iterationsPerWorker} iterations")
161+
162+
dataPath = bskPath + "/supportData/EphemerisData/"
163+
164+
startTime = time.time()
165+
166+
workerArgs = [
167+
(workerId, iterationsPerWorker, dataPath)
168+
for workerId in range(numWorkers)
169+
]
170+
171+
with mp.Pool(processes=numWorkers) as pool:
172+
workerResults = list(pool.starmap(createLoadDestroySpice, workerArgs))
173+
174+
endTime = time.time()
175+
executionTime = endTime - startTime
176+
177+
totalSuccess = sum(r["successCount"] for r in workerResults)
178+
totalFailure = sum(r["failureCount"] for r in workerResults)
179+
allExceptions = [e for r in workerResults for e in r["exceptions"]]
180+
181+
results = {
182+
"executionTime": executionTime,
183+
"totalIterations": numWorkers * iterationsPerWorker,
184+
"successfulIterations": totalSuccess,
185+
"failedIterations": totalFailure,
186+
"exceptions": allExceptions,
187+
}
188+
189+
print("\n--- SPICE Thread Safety Test Report ---")
190+
print(f"Total execution time: {executionTime:.2f} seconds")
191+
print(f"Total iterations: {numWorkers * iterationsPerWorker}")
192+
print(f"Successful iterations: {totalSuccess}")
193+
print(f"Failed iterations: {totalFailure}")
194+
print(f"Exceptions encountered: {len(allExceptions)}")
195+
print("--------------------------------------\n")
196+
197+
if totalSuccess == 0:
198+
print("TEST FAILED: No successful iterations completed")
199+
if len(allExceptions) > 0:
200+
print("\nFirst exception details:")
201+
print(allExceptions[0]["traceback"])
202+
success = False
203+
else:
204+
success = (totalFailure == 0)
205+
if success:
206+
print("TEST PASSED: SPICE interface thread safety looks robust")
207+
else:
208+
print("TEST FAILED: Issues detected with SPICE interface thread safety")
209+
if len(allExceptions) > 0:
210+
print("\nFirst exception details:")
211+
print(allExceptions[0]["traceback"])
212+
213+
return results, success
214+
215+
216+
def _runTestWithTimeout(resultQueue, numWorkers, iterationsPerWorker):
217+
"""
218+
Helper used as a process entry point to run the test with a timeout.
219+
220+
This is defined at module level so that it is picklable by
221+
multiprocessing on all supported platforms.
222+
"""
223+
try:
224+
results, success = runThreadSafetyTest(numWorkers, iterationsPerWorker)
225+
resultQueue.put((results, success))
226+
except Exception as exc:
227+
resultQueue.put(
228+
(
229+
{
230+
"error": str(exc),
231+
"traceback": traceback.format_exc(),
232+
},
233+
False,
234+
)
235+
)
236+
237+
238+
@pytest.mark.parametrize(
239+
"numWorkers, iterationsPerWorker",
240+
[
241+
(50, 3),
242+
],
243+
)
244+
def testSpiceThreadSafety(numWorkers, iterationsPerWorker):
245+
"""
246+
Pytest entry point for the SPICE thread-safety test.
247+
248+
Parameters
249+
----------
250+
numWorkers : int
251+
Number of parallel worker processes.
252+
iterationsPerWorker : int
253+
Number of load/unload cycles per worker.
254+
"""
255+
from multiprocessing import Process, Queue
256+
import queue
257+
258+
resultQueue = Queue()
259+
testProcess = Process(
260+
target=_runTestWithTimeout,
261+
args=(resultQueue, numWorkers, iterationsPerWorker),
262+
)
263+
testProcess.start()
264+
265+
timeoutSeconds = 60
266+
testProcess.join(timeoutSeconds)
267+
268+
if testProcess.is_alive():
269+
# Hard timeout: kill the worker process and fail the test
270+
testProcess.terminate()
271+
testProcess.join(1)
272+
if testProcess.is_alive():
273+
os.kill(testProcess.pid, 9)
274+
pytest.fail(
275+
f"Thread safety test timed out after {timeoutSeconds} seconds"
276+
)
277+
278+
try:
279+
results, success = resultQueue.get(block=False)
280+
281+
if isinstance(results, dict) and "error" in results:
282+
pytest.fail(
283+
"Thread safety test failed with error: "
284+
f"{results['error']}\n{results.get('traceback')}"
285+
)
286+
287+
assert success, "Thread safety test reported thread-safety issues"
288+
assert (
289+
results["failedIterations"] == 0
290+
), "Some iterations failed in the thread-safety test"
291+
except queue.Empty:
292+
pytest.fail(
293+
"Thread safety test completed but did not return any results"
294+
)
295+
296+
297+
if __name__ == "__main__":
298+
from multiprocessing import Process, Queue
299+
import queue
300+
301+
numWorkers = 50
302+
iterationsPerWorker = 3
303+
304+
if len(sys.argv) > 1:
305+
numWorkers = int(sys.argv[1])
306+
if len(sys.argv) > 2:
307+
iterationsPerWorker = int(sys.argv[2])
308+
309+
resultQueue = Queue()
310+
testProcess = Process(
311+
target=_runTestWithTimeout,
312+
args=(resultQueue, numWorkers, iterationsPerWorker),
313+
)
314+
testProcess.start()
315+
316+
timeoutSeconds = 60
317+
testProcess.join(timeoutSeconds)
318+
319+
if testProcess.is_alive():
320+
testProcess.terminate()
321+
testProcess.join(1)
322+
if testProcess.is_alive():
323+
os.kill(testProcess.pid, 9)
324+
print(
325+
f"ERROR: Thread safety test timed out after {timeoutSeconds} seconds"
326+
)
327+
sys.exit(2)
328+
329+
try:
330+
results, success = resultQueue.get(block=False)
331+
332+
if isinstance(results, dict) and "error" in results:
333+
print(
334+
"ERROR: Thread safety test failed with error: "
335+
f"{results['error']}"
336+
)
337+
print(results.get("traceback"))
338+
sys.exit(1)
339+
340+
sys.exit(0 if success else 1)
341+
except queue.Empty:
342+
print("ERROR: Thread safety test completed but did not return results")
343+
sys.exit(1)

0 commit comments

Comments
 (0)