Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ Configurations for Pace to use NDSL with different backend:

- Run: load pre-compiled program and execute, fail if the .so is not present (_no hash check!_) (backend must be `dace:gpu` or `dace:cpu`)

- PACE_FLOAT_PRECISION=64 control the floating point precision throughout the program.
- NDSL_FLOAT_PRECISION=64 control the floating point precision throughout the program.

Install Pace with different NDSL backend:

Expand Down
29 changes: 23 additions & 6 deletions ndsl/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from enum import Enum
from typing import Literal

import numpy as np

Expand All @@ -16,13 +17,29 @@ class ConstantVersions(Enum):
GEOS = "GEOS" # Constant as defined in GEOS v11.4.2


CONST_VERSION_AS_STR = os.environ.get("PACE_CONSTANTS", "UFS")
def _get_constant_version(
default: Literal["GFDL", "UFS", "GEOS"] = "UFS",
) -> Literal["GFDL", "UFS", "GEOS"]:
if os.getenv("PACE_CONSTANTS", ""):
ndsl_log.warning("PACE_CONSTANTS is deprecated. Use NDSL_CONSTANTS instead.")
if os.getenv("NDSL_CONSTANTS", ""):
ndsl_log.warning(
"PACE_CONSTANTS and NDSL_CONSTANTS were both specified. NDSL_CONSTANTS will take precedence."
)

try:
CONST_VERSION = ConstantVersions[CONST_VERSION_AS_STR]
ndsl_log.info(f"Constant selected: {CONST_VERSION}")
except KeyError as e:
raise RuntimeError(f"Constants {CONST_VERSION_AS_STR} is not implemented, abort.")
constants_as_str = os.getenv("NDSL_CONSTANTS", os.getenv("PACE_CONSTANTS", default))
expected: list[Literal["GFDL", "UFS", "GEOS"]] = ["GFDL", "UFS", "GEOS"]

if constants_as_str not in expected:
raise RuntimeError(
f"Constants '{constants_as_str}' is not implemented, abort. Valid values are {expected}."
)

return constants_as_str # type: ignore


CONST_VERSION = ConstantVersions[_get_constant_version()]
ndsl_log.info(f"Constant selected: {CONST_VERSION}")

#####################
# Common constants
Expand Down
30 changes: 29 additions & 1 deletion ndsl/dsl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Literal precision for both GT4Py & NDSL
import os
import sys
from typing import Literal

from ndsl.comm.mpi import MPI
from ndsl.logging import ndsl_log


gt4py_config_module = "gt4py.cartesian.config"
Expand All @@ -12,7 +14,33 @@
" Please import `ndsl.dsl` or any `ndsl` module "
" before any `gt4py` imports."
)
NDSL_GLOBAL_PRECISION = int(os.getenv("PACE_FLOAT_PRECISION", "64"))


def _get_literal_precision(default: Literal["32", "64"] = "64") -> Literal["32", "64"]:
if os.getenv("PACE_FLOAT_PRECISION", ""):
ndsl_log.warning(
"PACE_FLOAT_PRECISION is deprecated. Use NDSL_LITERAL_PRECISION instead."
)
if os.getenv("NDSL_LITERAL_PRECISION", ""):
ndsl_log.warning(
"PACE_FLOAT_PRECISION and NDSL_LOGLEVEL were both specified. NDSL_LITERAL_PRECISION will take precedence."
)

precision = os.getenv(
"NDSL_LITERAL_PRECISION", os.getenv("PACE_FLOAT_PRECISION", default)
)

expected: list[Literal["32", "64"]] = ["32", "64"]
if precision in expected:
return precision # type: ignore

ndsl_log.warning(
f"Unexpected literal precision '{precision}', falling back to '{default}'. Valid values are {expected}."
)
return default


NDSL_GLOBAL_PRECISION = int(_get_literal_precision())
os.environ["GT4PY_LITERAL_PRECISION"] = str(NDSL_GLOBAL_PRECISION)


Expand Down
24 changes: 18 additions & 6 deletions ndsl/dsl/dace/dace_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ndsl.dsl.caches.codepath import FV3CodePath
from ndsl.dsl.gt4py_utils import is_gpu_backend
from ndsl.dsl.typing import get_precision
from ndsl.logging import ndsl_log
from ndsl.optional_imports import cupy as cp


Expand All @@ -24,6 +25,21 @@
DEACTIVATE_DISTRIBUTED_DACE_COMPILE = False


def _debug_dace_orchestration() -> bool:
"""
Debugging Dace orchestration deeper can be done by turning on `syncdebug`.
We control this Dace configuration below with our own override.
"""
if os.getenv("PACE_DACE_DEBUG", ""):
ndsl_log.warning("PACE_DACE_DEBUG is deprecated. Use NDSL_DACE_DEBUG instead.")
if os.getenv("NDSL_DACE_DEBUG", ""):
ndsl_log.warning(
"PACE_DACE_DEBUG and NDSL_DACE_DEBUG were both specified. NDSL_DACE_DEBUG will take precedence."
)

return os.getenv("NDSL_DACE_DEBUG", os.getenv("PACE_DACE_DEBUG", "False")) == "True"


def _is_corner(rank: int, partitioner: Partitioner) -> bool:
if partitioner.tile.on_tile_bottom(rank):
if partitioner.tile.on_tile_left(rank):
Expand Down Expand Up @@ -178,14 +194,10 @@ def __init__(
else:
self._orchestrate = orchestration

# Debugging Dace orchestration deeper can be done by turning on `syncdebug`
# We control this Dace configuration below with our own override
dace_debug_env_var = os.getenv("PACE_DACE_DEBUG", "False") == "True"

# We hijack the optimization level of GT4Py because we don't
# have the configuration at NDSL level, but we do use the GT4Py
# level
# TODO: if GT4PY opt level is funnled via NDSL - use it here
# TODO: if GT4PY opt level is funneled via NDSL - use it here
optimization_level = GT4PY_COMPILE_OPT_LEVEL

# Set the configuration of DaCe to a rigid & tested set of divergence
Expand Down Expand Up @@ -283,7 +295,7 @@ def __init__(

# Enable to debug GPU failures
dace.config.Config.set(
"compiler", "cuda", "syncdebug", value=dace_debug_env_var
"compiler", "cuda", "syncdebug", value=_debug_dace_orchestration()
)

if get_precision() == 32:
Expand Down
5 changes: 3 additions & 2 deletions ndsl/dsl/typing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from typing import Tuple, TypeAlias, Union, cast

import numpy as np
from gt4py.cartesian import gtscript

from ndsl.dsl import NDSL_GLOBAL_PRECISION


# A Field
Field = gtscript.Field
Expand All @@ -23,7 +24,7 @@


def get_precision() -> int:
return int(os.getenv("PACE_FLOAT_PRECISION", "64"))
return NDSL_GLOBAL_PRECISION


# We redefine the type as a way to distinguish
Expand Down
33 changes: 27 additions & 6 deletions ndsl/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from mpi4py import MPI


LOGLEVEL = os.environ.get("PACE_LOGLEVEL", "INFO").lower()

# Python log levels are hierarchical, therefore setting INFO
# means DEBUG and everything lower will be logged.
AVAILABLE_LOG_LEVELS = {
Expand All @@ -21,12 +19,33 @@
}


def _get_log_level(default: str = "info"):
if os.getenv("PACE_LOGLEVEL", ""):
logging.warning("PACE_LOGLEVEL is deprecated. Use NDSL_LOGLEVEL instead.")
if os.getenv("NDSL_LOGLEVEL", ""):
logging.warning(
"PACE_LOGLEVEL and NDSL_LOGLEVEL were both specified. NDSL_LOGLEVEL will take precedence."
)

loglevel = os.getenv("NDSL_LOGLEVEL", os.getenv("PACE_LOGLEVEL", default)).lower()

if loglevel in AVAILABLE_LOG_LEVELS.keys():
return loglevel

logging.warning(
f"Unknown log level '{loglevel}', falling back to '{default}'. Valid values are: {AVAILABLE_LOG_LEVELS.keys()}."
)
return default


def _ndsl_logger() -> logging.Logger:
log_level = _get_log_level()

name_log = logging.getLogger(__name__)
name_log.setLevel(AVAILABLE_LOG_LEVELS[LOGLEVEL])
name_log.setLevel(AVAILABLE_LOG_LEVELS[log_level])

handler = logging.StreamHandler(sys.stdout)
handler.setLevel(AVAILABLE_LOG_LEVELS[LOGLEVEL])
handler.setLevel(AVAILABLE_LOG_LEVELS[log_level])
formatter = logging.Formatter(
fmt=(
f"%(asctime)s|%(levelname)s|rank {MPI.COMM_WORLD.Get_rank()}|"
Expand All @@ -40,14 +59,16 @@ def _ndsl_logger() -> logging.Logger:


def _ndsl_logger_on_rank_0() -> logging.Logger:
log_level = _get_log_level()

name_log = logging.getLogger(f"{__name__}_on_rank_0")
name_log.setLevel(AVAILABLE_LOG_LEVELS[LOGLEVEL])
name_log.setLevel(AVAILABLE_LOG_LEVELS[log_level])

rank = MPI.COMM_WORLD.Get_rank()

if rank == 0:
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(AVAILABLE_LOG_LEVELS[LOGLEVEL])
handler.setLevel(AVAILABLE_LOG_LEVELS[log_level])
formatter = logging.Formatter(
fmt=(
f"%(asctime)s|%(levelname)s|rank {MPI.COMM_WORLD.Get_rank()}|"
Expand Down