@@ -46,11 +46,14 @@ def get_backend_string(cls) -> str:
                 return "CPU"
 
     @classmethod
-    def change_backend_gempy(cls, engine_backend: AvailableBackends, use_gpu: bool = False, dtype: Optional[str] = None):
-        cls._change_backend(engine_backend, use_pykeops=PYKEOPS, use_gpu=use_gpu, dtype=dtype)
+    def change_backend_gempy(cls, engine_backend: AvailableBackends, use_gpu: bool = False,
+                             dtype: Optional[str] = None, grads: bool = False):
+        cls._change_backend(engine_backend, use_pykeops=PYKEOPS, use_gpu=use_gpu, dtype=dtype,
+                            grads=grads)
 
     @classmethod
-    def _change_backend(cls, engine_backend: AvailableBackends, use_pykeops: bool = False, use_gpu: bool = True, dtype: Optional[str] = None):
+    def _change_backend(cls, engine_backend: AvailableBackends, use_pykeops: bool = False,
+                        use_gpu: bool = True, dtype: Optional[str] = None, grads: bool = False):
         cls.dtype = DEFAULT_TENSOR_DTYPE if dtype is None else dtype
         cls.dtype_obj = cls.dtype
         match engine_backend:
@@ -99,6 +102,21 @@ def _change_backend(cls, engine_backend: AvailableBackends, use_pykeops: bool =
                 cls.dtype_obj = pytorch_copy.float32 if cls.dtype == "float32" else pytorch_copy.float64
                 cls.tensor_types = pytorch_copy.Tensor
 
+                torch.set_num_threads(torch.get_num_threads())  # Use all available threads
+                cls.COMPUTE_GRADS = grads  # Store the grads setting
+                if grads is False:
+                    cls._torch_no_grad_context = torch.no_grad()
+                    cls._torch_no_grad_context.__enter__()
+                else:
+                    # If there was a previous no_grad context, exit it first
+                    if hasattr(cls, '_torch_no_grad_context') and cls._torch_no_grad_context is not None:
+                        try:
+                            cls._torch_no_grad_context.__exit__(None, None, None)
+                        except Exception:
+                            pass  # Context might already be exited
+                        cls._torch_no_grad_context = None
+                    torch.set_grad_enabled(True)
+
         cls.use_pykeops = use_pykeops  # TODO: Make this compatible with pykeops
         if (use_pykeops):
             import pykeops
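For reference, here is a minimal usage sketch of the new `grads` flag. It assumes the classmethods in the diff live on `BackendTensor` in `gempy_engine.core.backend_tensor`, that `AvailableBackends` (with a `PYTORCH` member) is importable from `gempy_engine.config`, and that PyTorch is installed; the hunks above do not show the enclosing class or its imports, so treat those names as assumptions.

```python
# Sketch only: the module paths and the BackendTensor / AvailableBackends names
# are assumptions; the diff does not show the enclosing class or its imports.
import torch

from gempy_engine.config import AvailableBackends
from gempy_engine.core.backend_tensor import BackendTensor

# Default (grads=False): a torch.no_grad() context is entered and kept open,
# so autograd stays disabled for subsequent tensor operations.
BackendTensor.change_backend_gempy(AvailableBackends.PYTORCH, dtype="float64")
print(torch.is_grad_enabled())  # False

# grads=True: the stored no_grad context is exited and gradient tracking
# is switched back on for subsequent operations.
BackendTensor.change_backend_gempy(AvailableBackends.PYTORCH, dtype="float64", grads=True)
print(torch.is_grad_enabled())  # True
```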