Skip to content

Commit

Permalink
Better handling of valid_modes across problems, and fix frechet diffe…
Browse files Browse the repository at this point in the history
…rentiation for inhomogeneous terms
  • Loading branch information
kburns committed Mar 15, 2024
1 parent 0124b1b commit a46c36d
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 14 deletions.
16 changes: 10 additions & 6 deletions dedalus/core/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,16 +282,20 @@ def frechet_differential(self, variables, perturbations, backgrounds=None):
dtype = self.dtype
# Compute differential
epsilon = Field(dist=dist, dtype=dtype)
# d/dε F(X0 + ε*X1)
diff = self
for var, pert in zip(variables, perturbations):
diff = diff.replace(var, var + epsilon*pert)
diff = diff.sym_diff(epsilon)
diff = Operand.cast(diff, self.dist, tensorsig=tensorsig, dtype=dtype)
diff = diff.replace(epsilon, 0)
# Replace backgrounds
if backgrounds:
for var, bg in zip(variables, backgrounds):
diff = diff.replace(var, bg)
# ε -> 0
if diff:
diff = Operand.cast(diff, self.dist, tensorsig=tensorsig, dtype=dtype)
diff = diff.replace(epsilon, 0)
# Replace variables with backgrounds, if specified
if diff:
if backgrounds:
for var, bg in zip(variables, backgrounds):
diff = diff.replace(var, bg)
return diff

@property
Expand Down
16 changes: 13 additions & 3 deletions dedalus/core/problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def add_equation(self, equation, condition="True"):
'RHS': RHS,
'condition': condition,
'tensorsig': expr.tensorsig,
'dtype': expr.dtype}
'dtype': expr.dtype,
'valid_modes': LHS.valid_modes.copy()}
self._check_equation_conditions(eqn)
self._build_matrix_expressions(eqn)
self.equations.append(eqn)
Expand Down Expand Up @@ -385,6 +386,11 @@ def build_EVP(self, eigenvalue=None, backgrounds=None, perturbations=None, **kw)
eigenvalue = self.dist.Field(name='λ')
if perturbations is None:
perturbations = [var.copy() for var in variables]
for pert, var in zip(perturbations, variables):
if var.name:
pert.name = 'δ'+var.name
for pert, var in zip(perturbations, variables):
pert.valid_modes[:] = var.valid_modes
EVP = EigenvalueProblem(perturbations, eigenvalue, **kw)
# Convert equations from IVP
for eqn in self.equations:
Expand All @@ -404,8 +410,12 @@ def build_EVP(self, eigenvalue=None, backgrounds=None, perturbations=None, **kw)
if F:
if F.has(self.time):
raise UnsupportedEquationError("Cannot convert time-dependent IVP to EVP.")
F = F.frechet_differential(variables=variables, perturbations=perturbations, backgrounds=backgrounds)
EVP.add_equation((M + L - F, 0))
dF = F.frechet_differential(variables=variables, perturbations=perturbations, backgrounds=backgrounds)
else:
dF = 0
# Add linearized equation and copy valid modes
evp_eqn = EVP.add_equation((M + L - dF, 0))
evp_eqn['valid_modes'][:] = eqn['valid_modes']
return EVP


Expand Down
5 changes: 4 additions & 1 deletion dedalus/core/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,9 @@ def __init__(self, problem, **kw):
logger.debug('Beginning NLBVP instantiation')
super().__init__(problem, **kw)
self.perturbations = problem.perturbations
# Copy valid modes from variables to perturbations (may have been changed after problem instantiation)
for pert, var in zip(problem.perturbations, problem.variables):
pert.valid_modes[:] = var.valid_modes
self.iteration = 0
# Create RHS handler
F_handler = self.evaluator.add_system_handler(iter=1, group='F')
Expand Down Expand Up @@ -734,7 +737,7 @@ def evaluate_handlers(self, handlers=None, dt=0):

def log_stats(self, format=".4g"):
"""Log timing statistics with specified string formatting (optional)."""
self.run_time_end = self.wall_time
self.run_time_end = self.wall_time
start_time = self.start_time_end
logger.info(f"Final iteration: {self.iteration}")
logger.info(f"Final sim time: {self.sim_time}")
Expand Down
7 changes: 3 additions & 4 deletions dedalus/core/subsystems.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,7 @@ def inclusion_matrices(self, bases):
# matrix = sparse.identity(fsize, format='csr')[indices]
# return matrix.tocsr()

def valid_modes(self, field):
valid_modes = field.valid_modes
def valid_modes(self, field, valid_modes):
sp_slices = self.field_slices(field)
return valid_modes[sp_slices]

Expand Down Expand Up @@ -534,8 +533,8 @@ def build_matrices(self, names):
matrices[name] = sparse.coo_matrix((data, (rows, cols)), shape=(I, J), dtype=dtype).tocsr()

# Valid modes
valid_eqn = [self.valid_modes(eqn['LHS']) for eqn in eqns]
valid_var = [self.valid_modes(var) for var in vars]
valid_eqn = [self.valid_modes(eqn['LHS'], eqn['valid_modes']) for eqn in eqns]
valid_var = [self.valid_modes(var, var.valid_modes) for var in vars]
# Invalidate equations that fail condition test
for n, eqn_cond in enumerate(eqn_conditions):
if not eqn_cond:
Expand Down

0 comments on commit a46c36d

Please sign in to comment.