Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ Configurations for Pace to use NDSL with different backend:

- Run: load pre-compiled program and execute, fail if the .so is not present (_no hash check!_) (backend must be `dace:gpu` or `dace:cpu`)

- PACE_FLOAT_PRECISION=64 control the floating point precision throughout the program.
- NDSL_LITERAL_PRECISION=64 controls the floating point precision throughout the program.

Install Pace with different NDSL backend:

Expand Down
29 changes: 23 additions & 6 deletions ndsl/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from enum import Enum
from typing import Literal

import numpy as np

Expand All @@ -16,13 +17,29 @@ class ConstantVersions(Enum):
GEOS = "GEOS" # Constant as defined in GEOS v11.4.2


CONST_VERSION_AS_STR = os.environ.get("PACE_CONSTANTS", "UFS")
def _get_constant_version(
default: Literal["GFDL", "UFS", "GEOS"] = "UFS",
) -> Literal["GFDL", "UFS", "GEOS"]:
if os.getenv("PACE_CONSTANTS", ""):
ndsl_log.warning("PACE_CONSTANTS is deprecated. Use NDSL_CONSTANTS instead.")
if os.getenv("NDSL_CONSTANTS", ""):
ndsl_log.warning(
"PACE_CONSTANTS and NDSL_CONSTANTS were both specified. NDSL_CONSTANTS will take precedence."
)

try:
CONST_VERSION = ConstantVersions[CONST_VERSION_AS_STR]
ndsl_log.info(f"Constant selected: {CONST_VERSION}")
except KeyError as e:
raise RuntimeError(f"Constants {CONST_VERSION_AS_STR} is not implemented, abort.")
constants_as_str = os.getenv("NDSL_CONSTANTS", os.getenv("PACE_CONSTANTS", default))
expected: list[Literal["GFDL", "UFS", "GEOS"]] = ["GFDL", "UFS", "GEOS"]

if constants_as_str not in expected:
raise RuntimeError(
f"Constants '{constants_as_str}' is not implemented, abort. Valid values are {expected}."
)

return constants_as_str # type: ignore


CONST_VERSION = ConstantVersions[_get_constant_version()]
ndsl_log.info(f"Constant selected: {CONST_VERSION}")

#####################
# Common constants
Expand Down
30 changes: 29 additions & 1 deletion ndsl/dsl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Literal precision for both GT4Py & NDSL
import os
import sys
from typing import Literal

from ndsl.comm.mpi import MPI
from ndsl.logging import ndsl_log


gt4py_config_module = "gt4py.cartesian.config"
Expand All @@ -12,7 +14,33 @@
" Please import `ndsl.dsl` or any `ndsl` module "
" before any `gt4py` imports."
)
NDSL_GLOBAL_PRECISION = int(os.getenv("PACE_FLOAT_PRECISION", "64"))


def _get_literal_precision(default: Literal["32", "64"] = "64") -> Literal["32", "64"]:
if os.getenv("PACE_FLOAT_PRECISION", ""):
ndsl_log.warning(
"PACE_FLOAT_PRECISION is deprecated. Use NDSL_LITERAL_PRECISION instead."
)
if os.getenv("NDSL_LITERAL_PRECISION", ""):
ndsl_log.warning(
"PACE_FLOAT_PRECISION and NDSL_LOGLEVEL were both specified. NDSL_LITERAL_PRECISION will take precedence."
)

precision = os.getenv(
"NDSL_LITERAL_PRECISION", os.getenv("PACE_FLOAT_PRECISION", default)
)

expected: list[Literal["32", "64"]] = ["32", "64"]
if precision in expected:
return precision # type: ignore

ndsl_log.warning(
f"Unexpected literal precision '{precision}', falling back to '{default}'. Valid values are {expected}."
)
return default


NDSL_GLOBAL_PRECISION = int(_get_literal_precision())
os.environ["GT4PY_LITERAL_PRECISION"] = str(NDSL_GLOBAL_PRECISION)


Expand Down
24 changes: 18 additions & 6 deletions ndsl/dsl/dace/dace_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ndsl.dsl.caches.codepath import FV3CodePath
from ndsl.dsl.gt4py_utils import is_gpu_backend
from ndsl.dsl.typing import get_precision
from ndsl.logging import ndsl_log
from ndsl.optional_imports import cupy as cp


Expand All @@ -24,6 +25,21 @@
DEACTIVATE_DISTRIBUTED_DACE_COMPILE = False


def _debug_dace_orchestration() -> bool:
"""
Debugging Dace orchestration deeper can be done by turning on `syncdebug`.
We control this Dace configuration below with our own override.
"""
if os.getenv("PACE_DACE_DEBUG", ""):
ndsl_log.warning("PACE_DACE_DEBUG is deprecated. Use NDSL_DACE_DEBUG instead.")
if os.getenv("NDSL_DACE_DEBUG", ""):
ndsl_log.warning(
"PACE_DACE_DEBUG and NDSL_DACE_DEBUG were both specified. NDSL_DACE_DEBUG will take precedence."
)

return os.getenv("NDSL_DACE_DEBUG", os.getenv("PACE_DACE_DEBUG", "False")) == "True"


def _is_corner(rank: int, partitioner: Partitioner) -> bool:
if partitioner.tile.on_tile_bottom(rank):
if partitioner.tile.on_tile_left(rank):
Expand Down Expand Up @@ -178,14 +194,10 @@ def __init__(
else:
self._orchestrate = orchestration

# Debugging Dace orchestration deeper can be done by turning on `syncdebug`
# We control this Dace configuration below with our own override
dace_debug_env_var = os.getenv("PACE_DACE_DEBUG", "False") == "True"

# We hijack the optimization level of GT4Py because we don't
# have the configuration at NDSL level, but we do use the GT4Py
# level
# TODO: if GT4PY opt level is funnled via NDSL - use it here
# TODO: if GT4PY opt level is funneled via NDSL - use it here
optimization_level = GT4PY_COMPILE_OPT_LEVEL

# Set the configuration of DaCe to a rigid & tested set of divergence
Expand Down Expand Up @@ -283,7 +295,7 @@ def __init__(

# Enable to debug GPU failures
dace.config.Config.set(
"compiler", "cuda", "syncdebug", value=dace_debug_env_var
"compiler", "cuda", "syncdebug", value=_debug_dace_orchestration()
)

if get_precision() == 32:
Expand Down
5 changes: 3 additions & 2 deletions ndsl/dsl/typing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from typing import Tuple, TypeAlias, Union, cast

import numpy as np
from gt4py.cartesian import gtscript

from ndsl.dsl import NDSL_GLOBAL_PRECISION


# A Field
Field = gtscript.Field
Expand All @@ -23,7 +24,7 @@


def get_precision() -> int:
return int(os.getenv("PACE_FLOAT_PRECISION", "64"))
return NDSL_GLOBAL_PRECISION


# We redefine the type as a way to distinguish
Expand Down
33 changes: 27 additions & 6 deletions ndsl/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from mpi4py import MPI


LOGLEVEL = os.environ.get("PACE_LOGLEVEL", "INFO").lower()

# Python log levels are hierarchical, therefore setting INFO
# means DEBUG and everything lower will be logged.
AVAILABLE_LOG_LEVELS = {
Expand All @@ -21,12 +19,33 @@
}


def _get_log_level(default: str = "info"):
if os.getenv("PACE_LOGLEVEL", ""):
logging.warning("PACE_LOGLEVEL is deprecated. Use NDSL_LOGLEVEL instead.")
if os.getenv("NDSL_LOGLEVEL", ""):
logging.warning(
"PACE_LOGLEVEL and NDSL_LOGLEVEL were both specified. NDSL_LOGLEVEL will take precedence."
)

loglevel = os.getenv("NDSL_LOGLEVEL", os.getenv("PACE_LOGLEVEL", default)).lower()

if loglevel in AVAILABLE_LOG_LEVELS.keys():
return loglevel

logging.warning(
f"Unknown log level '{loglevel}', falling back to '{default}'. Valid values are: {AVAILABLE_LOG_LEVELS.keys()}."
)
return default


def _ndsl_logger() -> logging.Logger:
log_level = _get_log_level()

name_log = logging.getLogger(__name__)
name_log.setLevel(AVAILABLE_LOG_LEVELS[LOGLEVEL])
name_log.setLevel(AVAILABLE_LOG_LEVELS[log_level])

handler = logging.StreamHandler(sys.stdout)
handler.setLevel(AVAILABLE_LOG_LEVELS[LOGLEVEL])
handler.setLevel(AVAILABLE_LOG_LEVELS[log_level])
formatter = logging.Formatter(
fmt=(
f"%(asctime)s|%(levelname)s|rank {MPI.COMM_WORLD.Get_rank()}|"
Expand All @@ -40,14 +59,16 @@ def _ndsl_logger() -> logging.Logger:


def _ndsl_logger_on_rank_0() -> logging.Logger:
log_level = _get_log_level()

name_log = logging.getLogger(f"{__name__}_on_rank_0")
name_log.setLevel(AVAILABLE_LOG_LEVELS[LOGLEVEL])
name_log.setLevel(AVAILABLE_LOG_LEVELS[log_level])

rank = MPI.COMM_WORLD.Get_rank()

if rank == 0:
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(AVAILABLE_LOG_LEVELS[LOGLEVEL])
handler.setLevel(AVAILABLE_LOG_LEVELS[log_level])
formatter = logging.Formatter(
fmt=(
f"%(asctime)s|%(levelname)s|rank {MPI.COMM_WORLD.Get_rank()}|"
Expand Down