From 5ff1250d73f11fa4e2651773a02d51d3a8d138a1 Mon Sep 17 00:00:00 2001
From: Etienne Duchesne
Date: Sun, 30 Mar 2025 22:47:04 +0200
Subject: [PATCH 1/3] allow grad on complex inputs

---
 pytensor/gradient.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytensor/gradient.py b/pytensor/gradient.py
index 04572b29d0..b8000a6331 100644
--- a/pytensor/gradient.py
+++ b/pytensor/gradient.py
@@ -742,7 +742,7 @@ def handle_disconnected(var):
     for var in grad_dict:
         g = grad_dict[var]
         if hasattr(g.type, "dtype"):
-            assert g.type.dtype in pytensor.tensor.type.float_dtypes
+            assert g.type.dtype in pytensor.tensor.type.continuous_dtypes
 
     _rval: Sequence[Variable] = _populate_grad_dict(
         var_to_app_to_idx, grad_dict, _wrt, cost_name
@@ -1411,7 +1411,7 @@ def try_to_copy_if_needed(var):
                 )
 
                 if not isinstance(term.type, NullType | DisconnectedType):
-                    if term.type.dtype not in pytensor.tensor.type.float_dtypes:
+                    if term.type.dtype not in pytensor.tensor.type.continuous_dtypes:
                         raise TypeError(
                             str(node.op) + ".grad illegally "
                             " returned an integer-valued variable."
@@ -1562,7 +1562,7 @@ def _float_ones_like(x):
     """
 
     dtype = x.type.dtype
-    if dtype not in pytensor.tensor.type.float_dtypes:
+    if dtype not in pytensor.tensor.type.continuous_dtypes:
         dtype = config.floatX
 
     return x.ones_like(dtype=dtype)

From 64c4a79f55f39bdace8877bf33b9268f36df0f07 Mon Sep 17 00:00:00 2001
From: Etienne Duchesne
Date: Thu, 3 Apr 2025 13:06:49 +0200
Subject: [PATCH 2/3] extend numeric_grad to complex input functions

---
 pytensor/gradient.py | 33 ++++++++++++++++++++++++++++-----
 1 file changed, 28 insertions(+), 5 deletions(-)

diff --git a/pytensor/gradient.py b/pytensor/gradient.py
index b8000a6331..b8268c9418 100644
--- a/pytensor/gradient.py
+++ b/pytensor/gradient.py
@@ -1633,16 +1633,35 @@ def prod(inputs):
                 rval *= i
             return rval
 
+        def real_array(x):
+            dtype = np.array(x).dtype
+            if str(dtype).startswith("complex"):
+                return (
+                    np.stack([np.real(x), np.imag(x)], axis=-1),
+                    np.array([1.0, 1j]),
+                )
+            else:
+                return (np.expand_dims(np.array(x), axis=-1), np.array([1.0]))
+
+        def real_wrapper(f, c_stack):
+            def wrapped_f(*arr):
+                c_arr = [t @ s for (t, s) in zip(arr, c_stack, strict=True)]
+                return f(*c_arr)
+
+            return wrapped_f
+
         packed_pt = False
         if not isinstance(pt, list | tuple):
             pt = [pt]
             packed_pt = True
 
-        apt = [np.array(p) for p in pt]
+        apt, complex_stack = list(map(list, zip(*map(real_array, pt), strict=True)))
 
         shapes = [p.shape for p in apt]
         dtypes = [str(p.dtype) for p in apt]
 
+        real_f = real_wrapper(f, complex_stack)
+
         # TODO: remove this eventually (why was this here in the first place ?)
         # In the case of CSM, the arguments are a mixture of floats and
         # integers...
@@ -1677,7 +1696,7 @@ def prod(inputs):
             apt[i][...] = p
             cur_pos += p_size
 
-        f_x = f(*[p.copy() for p in apt])
+        real_f_x = real_f(*[p.copy() for p in apt])
 
         # now iterate over the elements of x, and call f on apt.
        x_copy = x.copy()
@@ -1685,14 +1704,18 @@ def prod(inputs):
             x[:] = x_copy
             x[i] += eps
 
-            f_eps = f(*apt)
+            real_f_eps = real_f(*apt)
 
             # TODO: remove this when it is clear that the next
             # replacemement does not pose problems of its own. It was replaced
             # for its inability to handle complex variables.
             # gx[i] = numpy.asarray((f_eps - f_x) / eps)
 
-            gx[i] = (f_eps - f_x) / eps
+            gx[i] = (real_f_eps - real_f_x) / eps
+
+        self.gf = [
+            (p @ s).conj() for (p, s) in zip(self.gf, complex_stack, strict=True)
+        ]
 
         if packed_pt:
             self.gf = self.gf[0]
@@ -1874,7 +1897,7 @@ def verify_grad(
     pt = [np.array(p) for p in pt]
 
     for i, p in enumerate(pt):
-        if p.dtype not in ("float16", "float32", "float64"):
+        if p.dtype not in pytensor.tensor.type.continuous_dtypes:
             raise TypeError(
                 "verify_grad can work only with floating point "
                 f'inputs, but input {i} has dtype "{p.dtype}".'

From 806fad119a87572ac96b5ca6771fdb600db86216 Mon Sep 17 00:00:00 2001
From: Etienne Duchesne
Date: Mon, 14 Apr 2025 19:20:32 +0200
Subject: [PATCH 3/3] add relative tolerance error for complex types

---
 pytensor/gradient.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/pytensor/gradient.py b/pytensor/gradient.py
index b8268c9418..1e8aa86818 100644
--- a/pytensor/gradient.py
+++ b/pytensor/gradient.py
@@ -1904,7 +1904,11 @@ def verify_grad(
             )
 
     _type_tol = dict(  # relative error tolerances for different types
-        float16=5e-2, float32=1e-2, float64=1e-4
+        float16=5e-2,
+        float32=1e-2,
+        float64=1e-4,
+        complex64=1e-2,
+        complex128=1e-4,
     )
 
     if abs_tol is None:
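For illustration only (not part of the patches above): a minimal NumPy sketch of the trick PATCH 2/3 adds to numeric_grad. Each complex input is stacked into a trailing [real, imag] axis, the finite-difference loop runs over that purely real representation, and the two partials are recombined with the [1, 1j] coefficients and conjugated, mirroring the (p @ s).conj() step. real_array below mirrors the helper introduced in the patch; cost and z0 are made-up names for the example.

import numpy as np


def real_array(x):
    # Complex input -> trailing [real, imag] axis plus the coefficients
    # needed to reassemble it; real input -> a length-1 trailing axis.
    x = np.asarray(x)
    if np.issubdtype(x.dtype, np.complexfloating):
        return np.stack([x.real, x.imag], axis=-1), np.array([1.0, 1j])
    return np.expand_dims(x, axis=-1), np.array([1.0])


def cost(z):
    # Toy real-valued cost of a complex vector.
    return (z**2).sum().real


z0 = np.array([1.0 + 2.0j, -0.5 + 0.25j])
xr, coeffs = real_array(z0)          # xr.shape == (2, 2): last axis is (real, imag)
assert np.allclose(xr @ coeffs, z0)  # reassembly: real + 1j * imag

eps = 1e-6
f0 = cost(xr @ coeffs)
g = np.zeros_like(xr)                # finite differences over the real stack
for idx in np.ndindex(xr.shape):
    xp = xr.copy()
    xp[idx] += eps
    g[idx] = (cost(xp @ coeffs) - f0) / eps

# Recombine the real/imaginary partials and conjugate, as the patch does
# with (p @ s).conj(); for this toy cost the result is approximately 2 * z0.
gz = (g @ coeffs).conj()
print(gz)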