Skip to content
This repository was archived by the owner on Nov 11, 2021. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 61 additions & 5 deletions docs/gt_frontend/guides/fvm/fvm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Up until now we have just considered a single control volume without actually ta
.. figure:: mesh.png
:width: 300
:align: center

Schematic of a 2D mesh

At this point different choices for the quantities to be solved for are possible. We will here use a vertex-centered approach where the unknowns are chosen to be the densities at the vertices of the mesh :math:`\rho_i^n = \rho^n_i(x_i)`, which are a first-order approximation of the average cell density :math:`\bar \rho_i^n` appearing in the time discretized form above.
Expand All @@ -80,12 +80,12 @@ At this point different choices for the quantities to be solved for are possible
\rho_i^{n+1} &= \rho_i^{n} - \frac{\delta t}{|\mathcal{V}_i|} \int_{\partial {\mathcal{V}_i}} \rho^n \mathbf{v} \cdot \mathbf{n} \mathrm{\,dA}
\end{align}

The control volumes :math:`\mathcal{V}_i` are then constructed by joining the (bary)centers of the cells adjacent to each vertex with the midpoint of the adjacent edges. The set of control volumes form another mesh, denoted the dual mesh.
The control volumes :math:`\mathcal{V}_i` are then constructed by joining the (bary)centers of the cells adjacent to each vertex with the midpoint of the adjacent edges. The set of control volumes forms another mesh, denoted the dual mesh.

.. figure:: fvm_median_dual_mesh_cv.png
:width: 300
:align: center

Schematic of the median-dual mesh in 2D. Primary mesh in black, dual mesh in blue. The control volume :math:`\mathcal{V}_i` around the vertex :math:`v_i` is constructed by joining the (bary)centers of adjacent cells with the midpoint of the outgoing edges of :math:`v_i`.

It remains to derive a discrete representation for the surface integral by first splitting the integral into its contributions on a set of segments :math:`S_j`, where each segment can be attributed to the edges adjacent to :math:`v_i`. Let :math:`|\mathcal{V}_i|` be the area of the control volume and :math:`l(i)` the number of edges adjacent to :math:`v_i` then
Expand Down Expand Up @@ -129,7 +129,63 @@ The resulting fully discrete time stepping scheme then reads

**Implementation in GT4Py**

To be written.
.. code-block:: python

@gtscript.stencil(externals={"vel": vel})
def fvm_advect(
mesh: Mesh,
rho: gtscript.Field[Vertex, dtype],
rho_next: gtscript.Field[Vertex, dtype],
volume: gtscript.Field[Vertex, dtype],
dual_normal: gtscript.Field[Edge, dtype],
dual_face_length: gtscript.Field[Edge, dtype],
face_orientation: gtscript.Field[Vertex, Edge, dtype], # either -1 or 1
#flux: gtscript.Field[Edge, dtype]
):
with computation(PARALLEL):
gtscript.Field[Edge, dtype]
# compute flux density through the intersection of the two
# control volumes around the dual cells associated with
# the vertices of `e` using an upwind scheme
with location(Edge) as e:
# upwind flux (instructive)
v1, v2 = vertices(e)
normal_velocity = dot(v, dual_normal[e]) # velocity projected onto the normal
if dot(vel, dual_normal[e]) > 0:
flux = rho[v1] * normal_velocity[e] * dual_face_length[e]
else:
flux = rho[v2] * normal_velocity[e] * dual_face_length[e]

# upwind flux (compact)
v1, v2 = vertices(e)
normal_velocity = dot(v, dual_normal[e]) # velocity projected onto the normal
flux = dual_face_area[e]*(max(0., normal_velocity)*rho[v1] + min(0., normal_velocity)*rho[v2])

# upwind flux (compact with weights)
flux = dual_face_area[e]*sum(rho, weights=[max(0., normal_velocity), min(0., normal_velocity)])

# centered flux (different flux just for comparison here)
flux = 0.5*sum(rho[v]*vel for v in vertices(e))
with location(Vertex) as v:
# compute density in the next timestep
rho_next = rho - δt/volume*sum(flux*face_orientation[v, e] for e in edges(v))

# parameters
vel = [1., 2.] # velocity
δt = 1e-6 # time step
niter = 100

# initialize mesh
# ...

# initialize fields
rho = zeros(mesh, dtype)
rho_next = zeros(mesh, dtype)
# todo: geometry: dual_volume, dual_normal, face_orientation

for i in range(niter):
fvm_advect(mesh, rho, rho_next, dual_volume, dual_normal, face_orientation, flux)
copyto(rho_next, rho)

**TODO**

Expand All @@ -143,4 +199,4 @@ Frame extension to IFS-FVM

**Notes**

Code for the construction of the dual mesh in Atlas: src/atlas/mesh/actions/BuildDualMesh.cc
Code for the construction of the dual mesh in Atlas: src/atlas/mesh/actions/BuildDualMesh.cc
23 changes: 21 additions & 2 deletions src/gt_frontend/gtscript_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
# todo(tehrengruber): document nodes
from typing import List, Union
from typing import List, Union, Optional

import gtc.common as common
from eve import Node
Expand Down Expand Up @@ -109,11 +109,31 @@ class BinaryOp(Expr):
left: Expr
right: Expr

class ListNode(Expr):  # todo: this node is not valid in every context
    """AST node for a Python list display, e.g. ``[a, b, c]``.

    Produced from ``ast.List`` (see the ``Patterns.ListNode`` matcher);
    used e.g. as the ``weights=[...]`` argument of a neighbor reduction.
    """

    # Element expressions of the list literal, in source order.
    elts: List[Expr]

class Keyword(GTScriptASTNode):
    """AST node for a ``key=value`` keyword argument of a call.

    Produced from ``ast.keyword`` (see the ``Patterns.Keyword`` matcher).
    """

    # Keyword name (the ``arg`` field of ``ast.keyword``).
    key: str
    # Expression bound to the keyword.
    value: Expr

class Call(Expr):
    """AST node for a function-call expression ``func(*args, **keywords)``.

    ``func`` is restricted to a plain identifier (the matcher captures
    ``ast.Name.id``), so attribute calls like ``obj.method(...)`` are not
    represented by this node.
    """

    # Positional argument expressions, in call order.
    args: List[Expr]
    # Keyword arguments; ``None`` when the call site supplied none.
    keywords: Optional[List[Keyword]]
    # Name of the called function.
    func: str

    # todo(tehrengruber): validate each keyword arg occurs only once

    def get_keyword_args_as_dict(self):
        """Return the keyword arguments as a ``{key: value-expression}`` dict."""
        return {arg.key: arg.value for arg in (self.keywords or [])}

    def has_keyword_arg(self, key):
        """Return ``True`` if a keyword argument named ``key`` was passed."""
        return key in self.get_keyword_args_as_dict()

    def get_keyword_arg(self, key):
        """Return the value expression of keyword argument ``key``.

        Raises:
            ValueError: if this call has no keyword argument named ``key``.
        """
        # Build the dict once instead of twice (previously via
        # ``has_keyword_arg`` plus a second lookup).
        try:
            return self.get_keyword_args_as_dict()[key]
        except KeyError:
            raise ValueError(
                f"Call to {self.func} has no keyword argument {key}"
            ) from None


# TODO(tehrengruber): can be enabled as soon as eve_toolchain#58 lands
# class Call(Generic[T]):
Expand All @@ -132,7 +152,6 @@ class Generator(Expr):
generators: List[LocationComprehension]
elt: Expr


class Assign(Statement):
target: Union[Symbol, SubscriptSingle, SubscriptMultiple]
value: Expr
Expand Down
13 changes: 9 additions & 4 deletions src/gt_frontend/gtscript_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Generator,
Interval,
IterationOrder,
ListNode,
LocationComprehension,
LocationSpecification,
Pass,
Expand Down Expand Up @@ -341,18 +342,22 @@ def visit_Call(self, node: Call, *, location_stack, **kwargs):
neighbors = self.visit(
node.args[0].generators[0], **{**kwargs, "location_stack": location_stack}
)
weights = None
if node.has_keyword_arg("weights"):
weights = node.get_keyword_arg("weights")
if not isinstance(weights, ListNode):
raise ValueError(f"Weights argument to neighbor reduction must be a list")

# operand gets new location stack
new_location_stack = location_stack + [neighbors]

weights = list(self.visit(weight, **{**kwargs, "location_stack": location_stack}) for weight in weights.elts)
operand = self.visit(
node.args[0].elt, **{**kwargs, "location_stack": new_location_stack}
node.args[0].elt, **{**kwargs, "location_stack": location_stack + [neighbors]}
)

return gtir.NeighborReduce(
op=op,
operand=operand,
neighbors=neighbors,
weights=weights,
location_type=location_stack[-1].chain.elements[-1],
)

Expand Down
47 changes: 27 additions & 20 deletions src/gt_frontend/py_to_gtscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def _all_subclasses(typ, *, module=None):
# map to symbols in the gtscript ast and are resolved there
assert issubclass(typ, enum.Enum)
return {typ}
elif typing_inspect.get_origin(typ) == list:
return {typing.List[sub_cls] for sub_cls in PyToGTScript._all_subclasses(typing_inspect.get_args(typ)[0], module=module)}
elif typing_inspect.is_union_type(typ):
return {
sub_cls
Expand Down Expand Up @@ -133,7 +135,11 @@ class Patterns:

BinaryOp = ast.BinOp(op=Capture("op"), left=Capture("left"), right=Capture("right"))

Call = ast.Call(args=Capture("args"), func=ast.Name(id=Capture("func")))
ListNode = ast.List(elts=Capture("elts"))

Keyword = ast.keyword(arg=Capture("key"), value=Capture("value"))

Call = ast.Call(args=Capture("args"), keywords=Capture("keywords"), func=ast.Name(id=Capture("func")))

LocationComprehension = ast.comprehension(
target=Capture("target"), iter=Capture("iterator")
Expand Down Expand Up @@ -170,7 +176,20 @@ def transform(self, node, eligible_node_types=None):
if eligible_node_types is None:
eligible_node_types = [gtscript_ast.Computation]

if isinstance(node, ast.AST):
if isinstance(node, typing.List):
# extract eligable node types which are lists
eligable_list_node_types = list(filter(lambda node_type: typing_inspect.get_origin(node_type) == list,
eligible_node_types))
if len(eligable_list_node_types) == 0:
raise ValueError(
f"Expected a list node, but got {type(node)}."
)

eligable_el_node_types = list(map(lambda list_node_type: typing_inspect.get_args(list_node_type)[0],
eligable_list_node_types))

return [self.transform(el, eligable_el_node_types) for el in node]
elif isinstance(node, ast.AST):
is_leaf_node = len(list(ast.iter_fields(node))) == 0
if is_leaf_node:
if not type(node) in self.leaf_map:
Expand All @@ -197,24 +216,12 @@ def transform(self, node, eligible_node_types=None):
name in node_type.__annotations__
), f"Invalid capture. No field named `{name}` in `{str(node_type)}`"
field_type = node_type.__annotations__[name]
if typing_inspect.get_origin(field_type) == list:
# determine eligible capture types
el_type = typing_inspect.get_args(field_type)[0]
eligible_capture_types = self._all_subclasses(el_type, module=module)

# transform captures recursively
transformed_captures[name] = []
for child_capture in capture:
transformed_captures[name].append(
self.transform(child_capture, eligible_capture_types)
)
else:
# determine eligible capture types
eligible_capture_types = self._all_subclasses(field_type, module=module)
# transform captures recursively
transformed_captures[name] = self.transform(
capture, eligible_capture_types
)
# determine eligible capture types
eligible_capture_types = self._all_subclasses(field_type, module=module)
# transform captures recursively
transformed_captures[name] = self.transform(
capture, eligible_capture_types
)
return node_type(**transformed_captures)
raise ValueError(
"Expected a node of type {}".format(
Expand Down
2 changes: 1 addition & 1 deletion src/gtc/unstructured/gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class NeighborReduce(Expr):
operand: Expr
op: ReduceOperator
neighbors: LocationComprehension
weights: Optional[List[Expr]]

@root_validator(pre=True)
def check_location_type(cls, values):
Expand Down Expand Up @@ -124,7 +125,6 @@ def check_location_type(cls, values):

return values


class VerticalDimension(Node):
pass

Expand Down
40 changes: 33 additions & 7 deletions src/gtc/unstructured/gtir_to_nir.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,32 @@ def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, last_block, **kwarg
loc_comprehension[node.neighbors.name] = node.neighbors
kwargs["location_comprehensions"] = loc_comprehension

neighbor_loop_name = "neighbor_loop_"+str(node.id_attr_)

if node.weights and node.neighbors.chain.elements != [common.LocationType.Edge, common.LocationType.Vertex]:
raise ValueError("Invalid usage of weights in NeighborReduce.")

body_location = node.neighbors.chain.elements[-1]
reduce_var_name = "local" + str(node.id_attr_)
last_block.declarations.append(
nir.LocalVar(
nir.ScalarLocalVar(
name=reduce_var_name,
vtype=common.DataType.FLOAT64, # TODO
location_type=node.location_type,
)
)
if node.weights:
weights_var_name = "local_weights_" + str(node.id_attr_)
last_block.declarations.append(
nir.LocalFieldVar(
name=weights_var_name,
vtype=common.DataType.FLOAT64, # TODO
domain=body_location,
init=self.visit(node.weights, **kwargs),
max_size=2, # TODO
location_type=node.location_type,
)
)
last_block.statements.append(
nir.AssignStmt(
left=nir.VarAccess(name=reduce_var_name, location_type=node.location_type),
Expand All @@ -142,24 +159,33 @@ def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, last_block, **kwarg
location_type=node.location_type,
),
)
reduction_item = self.visit(node.operand, in_neighbor_loop=True, **kwargs)
if node.weights:
reduction_item = nir.BinaryOp(
left=nir.LocalFieldAccess(name=weights_var_name, location=nir.NeighborLoopLocationAccess(
name=neighbor_loop_name, location_type=body_location), location_type=body_location),
op=common.BinaryOperator.MUL, right=reduction_item, location_type=body_location)

reduction_intermediate = nir.BinaryOp(
left=nir.VarAccess(name=reduce_var_name, location_type=body_location),
op=self.REDUCE_OP_TO_BINOP[node.op],
right=reduction_item,
location_type=body_location,
)
body = nir.BlockStmt(
declarations=[],
statements=[
nir.AssignStmt(
left=nir.VarAccess(name=reduce_var_name, location_type=body_location),
right=nir.BinaryOp(
left=nir.VarAccess(name=reduce_var_name, location_type=body_location),
op=self.REDUCE_OP_TO_BINOP[node.op],
right=self.visit(node.operand, in_neighbor_loop=True, **kwargs),
location_type=body_location,
),
right=reduction_intermediate,
location_type=body_location,
)
],
location_type=body_location,
)
last_block.statements.append(
nir.NeighborLoop(
name=neighbor_loop_name,
neighbors=self.visit(node.neighbors.chain),
body=body,
location_type=node.location_type,
Expand Down
Loading