diff --git a/src/parcels/_core/field.py b/src/parcels/_core/field.py index 672093a55..ffbf7731c 100644 --- a/src/parcels/_core/field.py +++ b/src/parcels/_core/field.py @@ -24,7 +24,7 @@ from parcels._core.uxgrid import UxGrid from parcels._core.xgrid import XGrid, _transpose_xfield_data_to_tzyx, assert_all_field_dims_have_axis from parcels._python import assert_same_function_signature -from parcels._reprs import default_repr +from parcels._reprs import field_repr, vectorfield_repr from parcels._typing import VectorType from parcels.interpolators import ( ZeroInterpolator, @@ -148,6 +148,9 @@ def __init__( if "time" not in self.data.coords: raise ValueError("Field data is missing a 'time' coordinate.") + def __repr__(self): + return field_repr(self) + @property def units(self): return self._units @@ -277,11 +280,7 @@ def __init__( self._vector_interp_method = vector_interp_method def __repr__(self): - return f"""<{type(self).__name__}> - name: {self.name!r} - U: {default_repr(self.U)} - V: {default_repr(self.V)} - W: {default_repr(self.W)}""" + return vectorfield_repr(self) @property def vector_interp_method(self): diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index 7f24127a0..6b6c14682 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -18,6 +18,7 @@ from parcels._core.uxgrid import UxGrid from parcels._core.xgrid import _DEFAULT_XGCM_KWARGS, XGrid from parcels._logger import logger +from parcels._reprs import fieldset_repr from parcels._typing import Mesh from parcels.interpolators import UxPiecewiseConstantFace, UxPiecewiseLinearNode, XConstantField, XLinear @@ -75,6 +76,9 @@ def __getattr__(self, name): else: raise AttributeError(f"FieldSet has no attribute '{name}'") + def __repr__(self): + return fieldset_repr(self) + @property def time_interval(self): """Returns the valid executable time interval of the FieldSet, diff --git a/src/parcels/_core/particle.py b/src/parcels/_core/particle.py index a9a187f30..86ddc5138 100644 --- a/src/parcels/_core/particle.py +++ b/src/parcels/_core/particle.py @@ -9,7 +9,7 @@ from parcels._core.statuscodes import StatusCode from parcels._core.utils.string import _assert_str_and_python_varname from parcels._core.utils.time import TimeInterval -from parcels._reprs import _format_list_items_multiline +from parcels._reprs import particleclass_repr, variable_repr __all__ = ["Particle", "ParticleClass", "Variable"] _TO_WRITE_OPTIONS = [True, False, "once"] @@ -70,7 +70,7 @@ def name(self): return self._name def __repr__(self): - return f"Variable(name={self._name!r}, dtype={self.dtype!r}, initial={self.initial!r}, to_write={self.to_write!r}, attrs={self.attrs!r})" + return variable_repr(self) class ParticleClass: @@ -92,8 +92,7 @@ def __init__(self, variables: list[Variable]): self.variables = variables def __repr__(self): - vars = [repr(v) for v in self.variables] - return f"ParticleClass(variables={_format_list_items_multiline(vars)})" + return particleclass_repr(self) def add_variable(self, variable: Variable | list[Variable]): """Add a new variable to the Particle class. This returns a new Particle class with the added variable(s). diff --git a/src/parcels/_core/particlefile.py b/src/parcels/_core/particlefile.py index 778e4275f..e8cb24b5c 100644 --- a/src/parcels/_core/particlefile.py +++ b/src/parcels/_core/particlefile.py @@ -16,6 +16,7 @@ import parcels from parcels._core.particle import ParticleClass from parcels._core.utils.time import timedelta_to_float +from parcels._reprs import particlefile_repr if TYPE_CHECKING: from parcels._core.particle import Variable @@ -96,12 +97,7 @@ def __init__(self, store, outputdt, chunks=None, create_new_zarrfile=True): # TODO v4: Add check that if create_new_zarrfile is False, the store already exists def __repr__(self) -> str: - return ( - f"{type(self).__name__}(" - f"outputdt={self.outputdt!r}, " - f"chunks={self.chunks!r}, " - f"create_new_zarrfile={self.create_new_zarrfile!r})" - ) + return particlefile_repr(self) def set_metadata(self, parcels_grid_mesh: Literal["spherical", "flat"]): self.metadata.update( diff --git a/src/parcels/_core/particleset.py b/src/parcels/_core/particleset.py index 8dea8efaa..72ae049cc 100644 --- a/src/parcels/_core/particleset.py +++ b/src/parcels/_core/particleset.py @@ -7,7 +7,6 @@ import numpy as np import xarray as xr from tqdm import tqdm -from zarr.storage import DirectoryStore from parcels._core.converters import _convert_to_flat_array from parcels._core.kernel import Kernel @@ -21,7 +20,7 @@ ) from parcels._core.warnings import ParticleSetWarning from parcels._logger import logger -from parcels._reprs import particleset_repr +from parcels._reprs import _format_zarr_output_location, particleset_repr __all__ = ["ParticleSet"] @@ -70,7 +69,6 @@ def __init__( **kwargs, ): self._data = None - self._repeat_starttime = None self._kernel = None self.fieldset = fieldset @@ -167,7 +165,7 @@ def __getattr__(self, name): def __getitem__(self, index): """Get a single particle by index.""" - return ParticleSetView(self._data, index=index) + return ParticleSetView(self._data, index=index, ptype=self._ptype) def __setattr__(self, name, value): if name in ["_data"]: @@ -447,7 +445,7 @@ def execute( # Set up pbar if output_file: - logger.info(f"Output files are stored in {_format_output_location(output_file.store)}") + logger.info(f"Output files are stored in {_format_zarr_output_location(output_file.store)}") if verbose_progress: pbar = tqdm(total=end_time - start_time, file=sys.stdout) @@ -592,9 +590,3 @@ def _get_start_time(first_release_time, time_interval, sign_dt, runtime): start_time = first_release_time if not np.isnan(first_release_time) else fieldset_start return start_time - - -def _format_output_location(zarr_obj): - if isinstance(zarr_obj, DirectoryStore): - return zarr_obj.path - return repr(zarr_obj) diff --git a/src/parcels/_core/particlesetview.py b/src/parcels/_core/particlesetview.py index c0ce88c04..b43860460 100644 --- a/src/parcels/_core/particlesetview.py +++ b/src/parcels/_core/particlesetview.py @@ -1,12 +1,15 @@ import numpy as np +from parcels._reprs import particlesetview_repr + class ParticleSetView: """Class to be used in a kernel that links a View of the ParticleSet (on the kernel level) to a ParticleSet.""" - def __init__(self, data, index): + def __init__(self, data, index, ptype): self._data = data self._index = index + self._ptype = ptype def __getattr__(self, name): # Return a proxy that behaves like the underlying numpy array but @@ -25,11 +28,14 @@ def __getattr__(self, name): return self._data[name][self._index] def __setattr__(self, name, value): - if name in ["_data", "_index"]: + if name in ["_data", "_index", "_ptype"]: object.__setattr__(self, name, value) else: self._data[name][self._index] = value + def __repr__(self): + return particlesetview_repr(self) + def __getitem__(self, index): # normalize single-element tuple indexing (e.g., (inds,)) if isinstance(index, tuple) and len(index) == 1: @@ -50,7 +56,7 @@ def __getitem__(self, index): raise ValueError( f"Boolean index has incompatible length {arr.size} for selection of size {int(np.sum(base))}" ) - return ParticleSetView(self._data, new_index) + return ParticleSetView(self._data, new_index, self._ptype) # Integer array/list, slice or single integer relative to the local view # (boolean masks were handled above). Normalize and map to global @@ -65,12 +71,12 @@ def __getitem__(self, index): base_arr = np.asarray(base) sel = base_arr[idx] new_index[sel] = True - return ParticleSetView(self._data, new_index) + return ParticleSetView(self._data, new_index, self._ptype) # Fallback: try to assign directly (preserves previous behaviour for other index types) try: new_index[base] = index - return ParticleSetView(self._data, new_index) + return ParticleSetView(self._data, new_index, self._ptype) except Exception as e: raise TypeError(f"Unsupported index type for ParticleSetView.__getitem__: {type(index)!r}") from e diff --git a/src/parcels/_core/utils/time.py b/src/parcels/_core/utils/time.py index fe62813ef..72843815e 100644 --- a/src/parcels/_core/utils/time.py +++ b/src/parcels/_core/utils/time.py @@ -6,6 +6,8 @@ import cftime import numpy as np +from parcels._reprs import timeinterval_repr + if TYPE_CHECKING: from parcels._typing import TimeLike @@ -61,7 +63,7 @@ def is_all_time_in_interval(self, time: float): return (0 <= item).all() and (item <= self.time_length_as_flt).all() def __repr__(self) -> str: - return f"TimeInterval(left={self.left!r}, right={self.right!r})" + return timeinterval_repr(self) def __eq__(self, other: object) -> bool: if not isinstance(other, TimeInterval): diff --git a/src/parcels/_core/xgrid.py b/src/parcels/_core/xgrid.py index 86805a60f..2de3de998 100644 --- a/src/parcels/_core/xgrid.py +++ b/src/parcels/_core/xgrid.py @@ -9,6 +9,7 @@ from parcels._core.basegrid import BaseGrid from parcels._core.index_search import _search_1d_array, _search_indices_curvilinear_2d +from parcels._reprs import xgrid_repr from parcels._typing import assert_valid_mesh _XGRID_AXES = Literal["X", "Y", "Z"] @@ -135,6 +136,9 @@ def from_dataset(cls, ds: xr.Dataset, mesh, xgcm_kwargs=None): grid = xgcm.Grid(ds, **xgcm_kwargs) return cls(grid, mesh=mesh) + def __repr__(self): + return xgrid_repr(self) + @property def axes(self) -> list[_XGRID_AXES]: return _get_xgrid_axes(self.xgcm_grid) diff --git a/src/parcels/_reprs.py b/src/parcels/_reprs.py index 6ea42992a..07f58b153 100644 --- a/src/parcels/_reprs.py +++ b/src/parcels/_reprs.py @@ -5,22 +5,142 @@ import textwrap from typing import TYPE_CHECKING, Any +import numpy as np +import xarray as xr +from zarr.storage import DirectoryStore + if TYPE_CHECKING: from parcels import Field, FieldSet, ParticleSet -def field_repr(field: Field) -> str: # TODO v4: Rework or remove entirely +def fieldset_repr(fieldset: FieldSet) -> str: + """Return a pretty repr for FieldSet""" + fields = [f for f in fieldset.fields.values() if getattr(f.__class__, "__name__", "") == "Field"] + vfields = [f for f in fieldset.fields.values() if getattr(f.__class__, "__name__", "") == "VectorField"] + + fields_repr = "\n".join([repr(f) for f in fields]) + vfields_repr = "\n".join([vectorfield_repr(vf, from_fieldset_repr=True) for vf in vfields]) + + out = f"""<{type(fieldset).__name__}> + fields: +{textwrap.indent(fields_repr, 8 * " ")} + vectorfields: +{textwrap.indent(vfields_repr, 8 * " ")} +""" + return textwrap.dedent(out).strip() + + +# TODO add land_value here after HG #2451 is merged +def field_repr(field: Field, level: int = 0) -> str: """Return a pretty repr for Field""" - out = f"""<{type(field).__name__}> - name : {field.name!r} - data : {field.data!r} - extrapolate time: {field.allow_time_extrapolation!r} + with xr.set_options(display_expand_data=False): + out = f"""<{type(field).__name__} {field.name!r}> + Parcels attributes: + name : {field.name!r} + interp_method : {field.interp_method!r} + time_interval : {field.time_interval!r} + units : {field.units!r} + igrid : {field.igrid!r} + DataArray: +{textwrap.indent(repr(field.data), 8 * " ")} +{textwrap.indent(repr(field.grid), 4 * " ")} +""" + return textwrap.indent(out, " " * level * 4).strip() + + +def vectorfield_repr(fieldset: FieldSet, from_fieldset_repr=False) -> str: + """Return a pretty repr for VectorField""" + out = f"""<{type(fieldset).__name__} {fieldset.name!r}> + Parcels attributes: + name : {fieldset.name!r} + vector_interp_method : {fieldset.vector_interp_method!r} + vector_type : {fieldset.vector_type!r} + {field_repr(fieldset.U, level=1) if not from_fieldset_repr else ""} + {field_repr(fieldset.V, level=1) if not from_fieldset_repr else ""} + {field_repr(fieldset.W, level=1) if not from_fieldset_repr and fieldset.W else ""}""" + return out + + +def xgrid_repr(grid: Any) -> str: + """Return a pretty repr for Grid""" + out = f"""<{type(grid).__name__}> + Parcels attributes: + mesh : {grid._mesh} + spatialhash : {grid._spatialhash} + xgcm Grid: +{textwrap.indent(repr(grid.xgcm_grid), 8 * " ")} +""" + return textwrap.dedent(out).strip() + + +def particleset_repr(pset: ParticleSet) -> str: + """Return a pretty repr for ParticleSet""" + if len(pset) < 10: + particles = [repr(p) for p in pset] + else: + particles = [repr(pset[i]) for i in range(7)] + ["..."] + [repr(pset[-1])] + + out = f"""<{type(pset).__name__}> + Number of particles: {len(pset)} + Particles: +{_format_list_items_multiline(particles, level=2, with_brackets=False)} + Pclass: +{textwrap.indent(repr(pset._ptype), 8 * " ")} +""" + return textwrap.dedent(out).strip() + + +def particlesetview_repr(pview: Any) -> str: + """Return a pretty repr for ParticleSetView""" + time_string = "not_yet_set" if pview.time is None or np.isnan(pview.time) else f"{pview.time:f}" + out = f"P[{pview.trajectory}]: time={time_string}, z={pview.z:f}, lat={pview.lat:f}, lon={pview.lon:f}" + vars = [v.name for v in pview._ptype.variables if v.to_write is True and v.name not in ["lon", "lat", "z", "time"]] + for var in vars: + out += f", {var}={getattr(pview, var):f}" + + return textwrap.dedent(out).strip() + + +def particleclass_repr(pclass: Any) -> str: + """Return a pretty repr for ParticleClass""" + vars = [repr(v) for v in pclass.variables] + out = f""" +{_format_list_items_multiline(vars, level=1, with_brackets=False)} +""" + return textwrap.dedent(out).strip() + + +def variable_repr(var: Any) -> str: + """Return a pretty repr for Variable""" + return f"Variable(name={var._name!r}, dtype={var.dtype!r}, initial={var.initial!r}, to_write={var.to_write!r}, attrs={var.attrs!r})" + + +def timeinterval_repr(ti: Any) -> str: + """Return a pretty repr for TimeInterval""" + return f"TimeInterval(left={ti.left!r}, right={ti.right!r})" + + +def particlefile_repr(pfile: Any) -> str: + """Return a pretty repr for ParticleFile""" + out = f"""<{type(pfile).__name__}> + store : {_format_zarr_output_location(pfile.store)} + outputdt : {pfile.outputdt!r} + chunks : {pfile.chunks!r} + create_new_zarrfile : {pfile.create_new_zarrfile!r} + metadata : +{_format_list_items_multiline(pfile.metadata, level=2, with_brackets=False)} """ return textwrap.dedent(out).strip() -def _format_list_items_multiline(items: list[str], level: int = 1) -> str: - """Given a list of strings, formats them across multiple lines. +def default_repr(obj: Any): + if is_builtin_object(obj): + return repr(obj) + return object.__repr__(obj) + + +def _format_list_items_multiline(items: list[str] | dict, level: int = 1, with_brackets: bool = True) -> str: + """Given a list of strings or a dict, formats them across multiple lines. Uses indentation levels of 4 spaces provided by ``level``. @@ -41,42 +161,22 @@ def _format_list_items_multiline(items: list[str], level: int = 1) -> str: indentation_str = level * 4 * " " indentation_str_end = (level - 1) * 4 * " " - items_str = ",\n".join([textwrap.indent(i, indentation_str) for i in items]) - return f"[\n{items_str}\n{indentation_str_end}]" - - -def particleset_repr(pset: ParticleSet) -> str: - """Return a pretty repr for ParticleSet""" - if len(pset) < 10: - particles = [repr(p) for p in pset] + if isinstance(items, dict): + entries = [f"{k!s}: {v!s}" for k, v in items.items()] else: - particles = [repr(pset[i]) for i in range(7)] + ["..."] + entries = [i if isinstance(i, str) else repr(i) for i in items] - out = f"""<{type(pset).__name__}> - fieldset : -{textwrap.indent(repr(pset.fieldset), " " * 8)} - ptype : {pset._ptype} - # particles: {len(pset)} - particles : {_format_list_items_multiline(particles, level=2)} -""" - return textwrap.dedent(out).strip() - - -def fieldset_repr(fieldset: FieldSet) -> str: # TODO v4: Rework or remove entirely - """Return a pretty repr for FieldSet""" - fields_repr = "\n".join([repr(f) for f in fieldset.fields.values()]) - - out = f"""<{type(fieldset).__name__}> - fields: -{textwrap.indent(fields_repr, 8 * " ")} -""" - return textwrap.dedent(out).strip() + if with_brackets: + items_str = ",\n".join([textwrap.indent(e, indentation_str) for e in entries]) + return f"[\n{items_str}\n{indentation_str_end}]" + else: + return "\n".join([textwrap.indent(e, indentation_str) for e in entries]) -def default_repr(obj: Any): - if is_builtin_object(obj): - return repr(obj) - return object.__repr__(obj) +def _format_zarr_output_location(zarr_obj): + if isinstance(zarr_obj, DirectoryStore): + return zarr_obj.path + return repr(zarr_obj) def is_builtin_object(obj): diff --git a/tests/test_particle.py b/tests/test_particle.py index 5daa2c8a3..4fc02ee9c 100644 --- a/tests/test_particle.py +++ b/tests/test_particle.py @@ -80,11 +80,9 @@ def test_particleclass_invalid_vars(): Variable("varc", dtype=np.float32, to_write=True), ] ), - """ParticleClass(variables=[ - Variable(name='vara', dtype=dtype('float32'), initial=0, to_write=True, attrs={}), - Variable(name='varb', dtype=dtype('float32'), initial=0, to_write=False, attrs={}), - Variable(name='varc', dtype=dtype('float32'), initial=0, to_write=True, attrs={}) -])""", + """Variable(name='vara', dtype=dtype('float32'), initial=0, to_write=True, attrs={}) +Variable(name='varb', dtype=dtype('float32'), initial=0, to_write=False, attrs={}) +Variable(name='varc', dtype=dtype('float32'), initial=0, to_write=True, attrs={})""", ), ], )