Skip to content

Commit 65166d8

Browse files
malfetpytorchmergebot
authored andcommitted
[MPS] Add regression test for sync deadlock (pytorch#141296)
See pytorch#140725 (comment) Running `torch.mps.synchronize()` after metal kernel resulted in infinite wait inside `[_MTLCommandBuffer waitUntilCompleted]` ``` (lldb) bt * thread #1, queue = 'com.apple.main-thread', stop reason = signal SIGSTOP * frame #0: 0x00000001aa919084 Metal`pthread_cond_wait + 12 frame #1: 0x00000001aa78b1b4 Metal`-[_MTLCommandBuffer waitUntilCompleted] + 84 frame #2: 0x00000001032bf358 libtorch_python.dylib`torch::mps::MPSModule_deviceSynchronize(_object*, _object*) + 40 frame #3: 0x0000000100e94c20 Python`cfunction_vectorcall_NOARGS + 100 frame #4: 0x0000000100e389b8 Python`PyObject_Vectorcall + 92 frame #5: 0x0000000100f61e38 Python`_PyEval_EvalFrameDefault + 19040 frame #6: 0x0000000100f5d180 Python`PyEval_EvalCode + 200 frame #7: 0x0000000100fcd1a4 Python`run_eval_code_obj + 104 frame #8: 0x0000000100fccbe4 Python`run_mod + 168 frame #9: 0x0000000100fcb518 Python`pyrun_file + 164 frame #10: 0x0000000100fca854 Python`_PyRun_SimpleFileObject + 256 frame #11: 0x0000000100fca4e8 Python`_PyRun_AnyFileObject + 80 frame #12: 0x0000000100ff2028 Python`pymain_run_file_obj + 164 frame #13: 0x0000000100ff1ce4 Python`pymain_run_file + 72 frame pytorch#14: 0x0000000100ff0f74 Python`Py_RunMain + 988 frame pytorch#15: 0x0000000100ff1564 Python`pymain_main + 304 frame pytorch#16: 0x0000000100ff1604 Python`Py_BytesMain + 40 frame pytorch#17: 0x000000019f630274 dyld`start + 2840 ``` Pull Request resolved: pytorch#141296 Approved by: https://github.com/huydhn
1 parent 25c0b91 commit 65166d8

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

test/test_mps.py

+8
Original file line numberDiff line numberDiff line change
@@ -8385,6 +8385,14 @@ def test_cumprod_dim_check(self):
83858385
self.assertRaises(IndexError, lambda: x.cumprod(2))
83868386
self.assertRaises(IndexError, lambda: x.cumprod(-3))
83878387

8388+
def test_do_sync_thrice_its_all_right(self):
8389+
# Regression test for https://github.com/pytorch/pytorch/commit/9bc9d4cdb4355a385a7d7959f07d04d1648d6904
8390+
# That caused sync calls to deadlock
8391+
x = torch.nextafter(torch.ones(1024, device='mps'), torch.zeros(1024, device='mps'))
8392+
for _ in range(3):
8393+
torch.mps.synchronize()
8394+
self.assertLess(x.sum().item(), x.numel())
8395+
83888396
class TestLogical(TestCaseMPS):
83898397
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
83908398
return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)

0 commit comments

Comments
 (0)