Skip to content

Commit

Permalink
Merge pull request #305 from DedalusProject/ufunc
Browse files Browse the repository at this point in the history
Broader support of custom functions
  • Loading branch information
kburns authored Dec 13, 2024
2 parents 7dc7ae6 + ba4f464 commit e141383
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 81 deletions.
3 changes: 2 additions & 1 deletion dedalus/core/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import sys
from functools import reduce
import numpy as np
from scipy import sparse
Expand Down Expand Up @@ -981,7 +982,7 @@ def sym_diff(self, var):

# Define aliases
for key, value in aliases.items():
exec(f"{key} = {value.__name__}")
setattr(sys.modules[__name__], key, value)

# Export aliases
__all__.extend(aliases.keys())
Expand Down
14 changes: 14 additions & 0 deletions dedalus/core/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,20 @@ def valid_modes(self):
# Return copy to avoid mangling cached result from coeff_layout
return valid_modes.copy()

@property
def real(self):
if self.is_real:
return self
else:
return (self + np.conj(self)) / 2

@property
def imag(self):
if self.is_real:
return 0
else:
return (self - np.conj(self)) / 2j


class Current(Operand):

Expand Down
120 changes: 71 additions & 49 deletions dedalus/core/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import sys
from collections import defaultdict
from functools import partial, reduce
import numpy as np
Expand Down Expand Up @@ -454,7 +455,7 @@ class GeneralFunction(NonlinearOperator, FutureField):
Notes
-----
On evaluation, this wrapper evaluates the provided funciton with the given
On evaluation, this wrapper evaluates the provided function with the given
arguments and keywords, and takes the output to be data in the specified
layout, i.e.
Expand Down Expand Up @@ -502,27 +503,69 @@ def operate(self, out):


class UnaryGridFunction(NonlinearOperator, FutureField):
"""
Wrapper for applying unary functions to fields in grid space.
This can be used with arbitrary user-defined functions, but
symbolic differentiation is only implemented for some scipy/numpy
universal functions.
supported = {ufunc.__name__: ufunc for ufunc in
(np.absolute, np.conj, np.exp, np.exp2, np.expm1,
np.log, np.log2, np.log10, np.log1p, np.sqrt, np.square,
np.sin, np.cos, np.tan, np.arcsin, np.arccos, np.arctan,
np.sinh, np.cosh, np.tanh, np.arcsinh, np.arccosh, np.arctanh,
scp.erf
)}
aliased = {'abs':np.absolute, 'conj':np.conjugate}
# Add ufuncs and shortcuts to parseables
parseables.update(supported)
parseables.update(aliased)

def __init__(self, func, arg, **kw):
if func not in self.supported.values():
raise ValueError("Unsupported ufunc: %s" %func)
#arg = Operand.cast(arg)
super().__init__(arg, **kw)
Parameters
----------
func : function
Unary function acting on grid data. Must be vectorized
and include an output array argument, e.g. func(x, out).
arg : dedalus operand
Argument field or operator.
deriv : function, optional
Symbolic derivative of func. Defaults are provided
for some common numpy/scipy ufuncs (default: None).
out : field, optional
Output field (default: new field).
Notes
-----
The supplied function must support an output argument called 'out'
and act in a vectorized fashion. The action is essentially:
func(arg['g'], out=out['g'])
"""

ufunc_derivatives = {
np.absolute: lambda x: np.sign(x),
np.sign: lambda x: 0,
np.exp: lambda x: np.exp(x),
np.exp2: lambda x: np.exp2(x) * np.log(2),
np.log: lambda x: x**(-1),
np.log2: lambda x: (x * np.log(2))**(-1),
np.log10: lambda x: (x * np.log(10))**(-1),
np.sqrt: lambda x: (1/2) * x**(-1/2),
np.square: lambda x: 2*x,
np.sin: lambda x: np.cos(x),
np.cos: lambda x: -np.sin(x),
np.tan: lambda x: np.cos(x)**(-2),
np.arcsin: lambda x: (1 - x**2)**(-1/2),
np.arccos: lambda x: -(1 - x**2)**(-1/2),
np.arctan: lambda x: (1 + x**2)**(-1),
np.sinh: lambda x: np.cosh(x),
np.cosh: lambda x: np.sinh(x),
np.tanh: lambda x: 1-np.tanh(x)**2,
np.arcsinh: lambda x: (x**2 + 1)**(-1/2),
np.arccosh: lambda x: (x**2 - 1)**(-1/2),
np.arctanh: lambda x: (1 - x**2)**(-1),
scp.erf: lambda x: 2*(np.pi)**(-1/2)*np.exp(-x**2)}

# Add ufuncs and shortcuts to aliases
aliases.update({ufunc.__name__: ufunc for ufunc in ufunc_derivatives})
aliases.update({'abs': np.absolute, 'conj': np.conjugate})

def __init__(self, func, arg, deriv=None, out=None):
super().__init__(arg, out=out)
self.func = func
if arg.tensorsig:
raise ValueError("Ufuncs not defined for non-scalar fields.")
if deriv is None and func in self.ufunc_derivatives:
self.deriv = self.ufunc_derivatives[func]
else:
self.deriv = deriv
# FutureField requirements
self.domain = arg.domain
self.tensorsig = arg.tensorsig
Expand All @@ -538,40 +581,19 @@ def _build_bases(self, arg0):
bases = arg0.domain
return bases

def new_operands(self, arg):
return UnaryGridFunction(self.func, arg)
def new_operand(self, arg):
return UnaryGridFunction(self.func, arg, deriv=self.deriv)

def reinitialize(self, **kw):
arg = self.args[0].reinitialize(**kw)
return self.new_operands(arg)
return self.new_operand(arg)

def sym_diff(self, var):
"""Symbolically differentiate with respect to specified operand."""
diff_map = {np.absolute: lambda x: np.sign(x),
np.sign: lambda x: 0,
np.exp: lambda x: np.exp(x),
np.exp2: lambda x: np.exp2(x) * np.log(2),
np.log: lambda x: x**(-1),
np.log2: lambda x: (x * np.log(2))**(-1),
np.log10: lambda x: (x * np.log(10))**(-1),
np.sqrt: lambda x: (1/2) * x**(-1/2),
np.square: lambda x: 2*x,
np.sin: lambda x: np.cos(x),
np.cos: lambda x: -np.sin(x),
np.tan: lambda x: np.cos(x)**(-2),
np.arcsin: lambda x: (1 - x**2)**(-1/2),
np.arccos: lambda x: -(1 - x**2)**(-1/2),
np.arctan: lambda x: (1 + x**2)**(-1),
np.sinh: lambda x: np.cosh(x),
np.cosh: lambda x: np.sinh(x),
np.tanh: lambda x: 1-np.tanh(x)**2,
np.arcsinh: lambda x: (x**2 + 1)**(-1/2),
np.arccosh: lambda x: (x**2 - 1)**(-1/2),
np.arctanh: lambda x: (1 - x**2)**(-1),
scp.erf: lambda x: 2*(np.pi)**(-1/2)*np.exp(-x**2)}
if self.deriv is None:
raise ValueError(f"Symbolic derivative not implemented for {self.func.__name__}.")
arg = self.args[0]
arg_diff = arg.sym_diff(var)
return diff_map[self.func](arg) * arg_diff
return self.deriv(arg) * arg.sym_diff(var)

def check_conditions(self):
# Field must be in grid layout
Expand Down Expand Up @@ -1024,7 +1046,7 @@ def matrix_coupling(self, *vars):
return self.operand.matrix_coupling(*vars)


@parseable('interpolate', 'interp')
#@parseable('interpolate', 'interp')
def interpolate(arg, **positions):
# Identify domain
domain = unify_attributes((arg,)+tuple(positions), 'domain', require=False)
Expand Down Expand Up @@ -4386,7 +4408,7 @@ def compute_cfl_frequency(self, velocity, out):

# Define aliases
for key, value in aliases.items():
exec(f"{key} = {value.__name__}")
setattr(sys.modules[__name__], key, value)

# Export aliases
__all__.extend(aliases.keys())
2 changes: 0 additions & 2 deletions dedalus/core/problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@
# Build basic parsing namespace
parseables = {}
parseables.update({name: getattr(operators, name) for name in operators.__all__})
parseables.update(operators.aliases)
parseables.update({name: getattr(arithmetic, name) for name in arithmetic.__all__})
parseables.update(arithmetic.aliases)


class ProblemBase:
Expand Down
2 changes: 1 addition & 1 deletion dedalus/tests/test_grid_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
N_range = [16]
dealias_range = [1]
dtype_range = [np.float64, np.complex128]
ufuncs = d3.UnaryGridFunction.supported.values()
ufuncs = d3.UnaryGridFunction.ufunc_derivatives.keys()


@CachedMethod
Expand Down
48 changes: 20 additions & 28 deletions docs/pages/general_functions.rst
Original file line number Diff line number Diff line change
@@ -1,44 +1,36 @@
General Functions
*****************

**Note: this documentation has not yet been updated for v3 of Dedalus.**

The ``GeneralFunction`` class enables users to simply define new explicit operators for the right-hand side and analysis tasks of their simulations.
The ``GeneralFunction`` and ``UnaryGridFunction`` classes enables users to simply define new explicit operators for the right-hand side and analysis tasks of their simulations.
Such operators can be used to apply arbitrary user-defined functions to the grid values or coefficients of some set of input fields, or even do things like introduce random data or read data from an external source.

A ``GeneralFunction`` object is instantiated with a Dedalus domain, a layout object or descriptor (e.g. ``'g'`` or ``'c'`` for grid or coefficient space), a function, a list of arguments, and a dictionary of keywords.
A ``GeneralFunction`` object is instantiated with a Dedalus distributor, domain, tensor signature, dtype, layout object or descriptor (e.g. ``'g'`` or ``'c'`` for grid or coefficient space), function, list of arguments, and dictionary of keywords.
The resulting object is a Dedalus operator that can be evaluated and composed like other Dedalus operators.
It operates by first ensuring that any arguments that are Dedalus field objects are in the specified layout, then calling the function with the specified arguments and keywords, and finally setting the result as the output data in the specified layout.

Here's an example how you can use this class to apply a nonlinear function to the grid data of a single Dedalus field.
First, we define the underlying function we want to apply to the field data -- say the error function from scipy:

.. code-block:: python
from scipy import special
A simpler option that should work for many use cases is the ``UnaryGridFunction`` class, which specifically applies a function to the grid data of a single field.
The output field's distributor, domain/bases, tensor signature, and dtype are all taken to be idential to those of the input field.
Only the function and input field need to be specified.
The function must be vectorized, take a single Numpy array as input, and include an ``out`` argument that specifies the output array.
Applying most Numpy or Scipy universal functions to a Dedalus field will automatically produce the corresponding ``UnaryGridFunction`` operator.

def erf_func(field):
# Call scipy erf function on the field's data
return special.erf(field.data)
Second, we make a wrapper that returns a ``GeneralFunction`` instance that applies ``erf_func`` to a provided field in grid space.
This function produces a Dedalus operator, so it's what we want to use on the RHS or in analysis tasks:
Here's an example of using the ``UnaryGridFunction`` class to apply a custom function to the grid data of a single Dedalus field.
First, we define the underlying function we want to apply to the field data:

.. code-block:: python
import dedalus.public as de
def erf_operator(field):
# Return GeneralFunction instance that applies erf_func in grid space
return de.operators.GeneralFunction(
field.domain,
layout = 'g',
func = erf_func,
args = (field,)
)
# Custom function acting on grid data
def custom_grid_function(x, out):
out[:] = (x + np.abs(x)) / 2
return out
Finally, we add this wrapper to the parsing namespace to make it available in string-specified equations and analysis tasks:
Second, we make a wrapper that returns a ``UnaryGridFunction`` instance that applies ``custom_grid_function`` to a specified field.
This wrapper produces a Dedalus operator, so it's what we want to use on the RHS or in analysis tasks:

.. code-block:: python
de.operators.parseables['erf'] = erf_operator
# Operator wrapper for custom function
custom_grid_operator = lambda field: d3.UnaryGridFunction(custom_grid_function, field)
# Analysis task applying custom operator to a field
snapshots.add_task(custom_grid_operator(u), name="custom(u)")
1 change: 1 addition & 0 deletions docs/pages/user_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ Specific how-to's:
gauge_conditions
tau_method
half_dimensions
general_functions

0 comments on commit e141383

Please sign in to comment.