Skip to content

Commit

Permalink
[1.9] AssetKey path as a tuple (dagster-io#25240)
Browse files Browse the repository at this point in the history
Storing `path` as a tuple and avoiding custom `__eq__` and `__hash__`
functions results in a substantial performance improvement for
operations like building up a large global asset graph.

## How I Tested These Changes
For this target large graph, function calls decreased by ~70% and
execution time decreased by ~50%

Before:
```
Profiling asset graph...
         353461 function calls in 0.083 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   119166    0.021    0.000    0.027    0.000 asset_key.py:56(__hash__)
        1    0.012    0.012    0.080    0.080 remote_asset_graph.py:303(_build)
        1    0.008    0.008    0.018    0.018 remote_asset_graph.py:355(<dictcomp>)
   119166    0.006    0.000    0.006    0.000 {built-in method builtins.hash}
        1    0.005    0.005    0.012    0.012 remote_asset_graph.py:502(_build_execution_set_index)
     7790    0.003    0.000    0.006    0.000 external_data.py:1313(key)
        2    0.003    0.002    0.006    0.003 remote_asset_graph.py:481(_warn_on_duplicates_within_subset)
    17250    0.003    0.000    0.007    0.000 {method 'add' of 'set' objects}
        1    0.002    0.002    0.083    0.083 remote_asset_graph.py:276(from_workspace_snapshot)
     5012    0.002    0.000    0.004    0.000 remote_asset_graph.py:51(__init__)
     9598    0.002    0.000    0.002    0.000 asset_key.py:59(__eq__)
        1    0.002    0.002    0.003    0.003 remote_asset_graph.py:330(<dictcomp>)
     5012    0.002    0.000    0.002    0.000 remote_asset_graph.py:64(<listcomp>)
        1    0.001    0.001    0.003    0.003 remote_asset_graph.py:329(<dictcomp>)
     7790    0.001    0.000    0.002    0.000 <string>:1(<lambda>)
        1    0.001    0.001    0.002    0.002 remote_asset_graph.py:350(<dictcomp>)
        1    0.001    0.001    0.008    0.008 remote_asset_graph.py:455(_warn_on_duplicate_nodes)
     7798    0.001    0.000    0.001    0.000 {built-in method __new__ of type object at 0x10320bcb0}
        1    0.001    0.001    0.002    0.002 remote_asset_graph.py:328(<setcomp>)
    21747    0.001    0.000    0.001    0.000 {method 'append' of 'list' objects}
    19842    0.001    0.000    0.001    0.000 {built-in method builtins.isinstance}
        1    0.001    0.001    0.001    0.001 remote_asset_graph.py:256(<dictcomp>)
        2    0.001    0.000    0.001    0.000 remote_asset_graph.py:489(<dictcomp>)
...
```

After:
```
Profiling asset graph...
         105531 function calls in 0.043 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.009    0.009    0.040    0.040 remote_asset_graph.py:303(_build)
        1    0.006    0.006    0.010    0.010 remote_asset_graph.py:355(<dictcomp>)
        1    0.004    0.004    0.006    0.006 remote_asset_graph.py:502(_build_execution_set_index)
        2    0.003    0.001    0.004    0.002 remote_asset_graph.py:481(_warn_on_duplicates_within_subset)
     7790    0.003    0.000    0.004    0.000 external_data.py:1313(key)
     5012    0.002    0.000    0.002    0.000 remote_asset_graph.py:64(<listcomp>)
        1    0.002    0.002    0.043    0.043 remote_asset_graph.py:276(from_workspace_snapshot)
     5012    0.002    0.000    0.004    0.000 remote_asset_graph.py:51(__init__)
    17250    0.002    0.000    0.002    0.000 {method 'add' of 'set' objects}
        1    0.001    0.001    0.001    0.001 remote_asset_graph.py:329(<dictcomp>)
        1    0.001    0.001    0.005    0.005 remote_asset_graph.py:455(_warn_on_duplicate_nodes)
        1    0.001    0.001    0.001    0.001 remote_asset_graph.py:330(<dictcomp>)
     7790    0.001    0.000    0.002    0.000 <string>:1(<lambda>)
        1    0.001    0.001    0.001    0.001 remote_asset_graph.py:350(<dictcomp>)
     7798    0.001    0.000    0.001    0.000 {built-in method __new__ of type object at 0x1050b3cb0}
    21747    0.001    0.000    0.001    0.000 {method 'append' of 'list' objects}
    19842    0.001    0.000    0.001    0.000 {built-in method builtins.isinstance}
        1    0.001    0.001    0.001    0.001 remote_asset_graph.py:256(<dictcomp>)
        2    0.001    0.000    0.001    0.000 remote_asset_graph.py:489(<dictcomp>)
        1    0.001    0.001    0.001    0.001 remote_asset_graph.py:328(<setcomp>)
...
```

## Changelog

[breaking] `AssetKey` can no longer be iterated over or indexed in to.
This behavior was never an intended access pattern and in all observed
cases was a mistake.
  • Loading branch information
alangenfeld authored and Grzyblon committed Oct 26, 2024
1 parent a58514a commit e97c61b
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ To add the `order_count_chart` asset:
```python file=integrations/dbt/tutorial/downstream_assets/assets.py startafter=start_downstream_asset endbefore=end_downstream_asset
@asset(
compute_kind="python",
deps=get_asset_key_for_model([jaffle_shop_dbt_assets], "customers"),
deps=[get_asset_key_for_model([jaffle_shop_dbt_assets], "customers")],
)
def order_count_chart(context: AssetExecutionContext):
# read the contents of the customers table into a Pandas DataFrame
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def jaffle_shop_dbt_assets(context: AssetExecutionContext, dbt: DbtCliResource):
# start_downstream_asset
@asset(
compute_kind="python",
deps=get_asset_key_for_model([jaffle_shop_dbt_assets], "customers"),
deps=[get_asset_key_for_model([jaffle_shop_dbt_assets], "customers")],
)
def order_count_chart(context: AssetExecutionContext):
# read the contents of the customers table into a Pandas DataFrame
Expand Down
52 changes: 37 additions & 15 deletions python_modules/dagster/dagster/_core/definitions/asset_key.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import re
from functools import cached_property
from typing import TYPE_CHECKING, Any, List, Mapping, NamedTuple, Optional, Sequence, TypeVar, Union

import dagster._check as check
import dagster._seven as seven
from dagster._annotations import PublicAttr
from dagster._annotations import PublicAttr, public
from dagster._core.errors import DagsterInvariantViolationError
from dagster._record import IHaveNew, record_custom
from dagster._serdes import whitelist_for_serdes

ASSET_KEY_SPLIT_REGEX = re.compile("[^a-zA-Z0-9_]")
Expand All @@ -19,7 +22,11 @@ def parse_asset_key_string(s: str) -> Sequence[str]:


@whitelist_for_serdes
class AssetKey(NamedTuple("_AssetKey", [("path", PublicAttr[Sequence[str]])])):
@record_custom(
checked=False,
field_to_new_mapping={"parts": "path"},
)
class AssetKey(IHaveNew):
"""Object representing the structure of an asset key. Takes in a sanitized string, list of
strings, or tuple of strings.
Expand All @@ -39,29 +46,31 @@ class AssetKey(NamedTuple("_AssetKey", [("path", PublicAttr[Sequence[str]])])):
strings represent the hierarchical structure of the asset_key.
"""

def __new__(cls, path: Union[str, Sequence[str]]):
# Originally AssetKey contained "path" as a list. In order to change to using a tuple, we now have
parts: Sequence[str] # with path available as a property defined below still returning a list.

def __new__(
cls,
path: Union[str, Sequence[str]],
):
if isinstance(path, str):
path = [path]
parts = (path,)
else:
path = list(check.sequence_param(path, "path", of_type=str))
parts = tuple(check.sequence_param(path, "path", of_type=str))

return super().__new__(cls, parts=parts)

return super(AssetKey, cls).__new__(cls, path=path)
@public
@cached_property
def path(self) -> Sequence[str]:
return list(self.parts)

def __str__(self):
return f"AssetKey({self.path})"

def __repr__(self):
return f"AssetKey({self.path})"

def __hash__(self):
return hash(tuple(self.path))

def __eq__(self, other):
if other.__class__ is not self.__class__:
return False

return self.path == other.path

def to_string(self) -> str:
"""E.g. '["first_component", "second_component"]'."""
return self.to_db_string()
Expand Down Expand Up @@ -150,6 +159,19 @@ def with_prefix(self, prefix: "CoercibleToAssetKeyPrefix") -> "AssetKey":
prefix = key_prefix_from_coercible(prefix)
return AssetKey(list(prefix) + list(self.path))

def __iter__(self):
raise DagsterInvariantViolationError(
"You have attempted to iterate a single AssetKey object. "
"As of 1.9, this behavior is disallowed because it is likely unintentional and a bug."
)

def __getitem__(self, _):
raise DagsterInvariantViolationError(
"You have attempted to index directly in to the AssetKey object. "
"As of 1.9, this behavior is disallowed because it is likely unintentional and a bug. "
"Use asset_key.path instead to access the list of key components."
)


CoercibleToAssetKey = Union[AssetKey, str, Sequence[str]]
CoercibleToAssetKeyPrefix = Union[str, Sequence[str]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def resolve_similar_asset_names(
# If the asset key provided has no prefix and the upstream key has
# the same name but a prefix of any length
no_prefix_but_is_match_with_prefix = (
len(target_asset_key) == 1 and asset_key.path[-1] == target_asset_key.path[-1]
len(target_asset_key.path) == 1 and asset_key.path[-1] == target_asset_key.path[-1]
)

matches_slashes_turned_to_prefix_gaps = asset_key.path == target_asset_key_split
Expand Down
10 changes: 7 additions & 3 deletions python_modules/dagster/dagster/_record/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,15 @@ def _namedtuple_record_transform(
# the default namedtuple record cannot handle subclasses that have different fields from their
# parents if both are records
base.__repr__ = _repr
nt_iter = base.__iter__
base.__iter__ = _banned_iter
base.__getitem__ = _banned_idx

# these will override an implementation on the class if it exists
new_class_dict = {
**{n: getattr(base, n) for n in field_set.keys()},
"_fields": base._fields,
"__iter__": _banned_iter,
"__getitem__": _banned_idx,
"__hidden_iter__": base.__iter__,
"__hidden_iter__": nt_iter,
"__hidden_replace__": base._replace,
_RECORD_MARKER_FIELD: _RECORD_MARKER_VALUE,
_RECORD_ANNOTATIONS_FIELD: field_set,
Expand Down Expand Up @@ -322,6 +323,9 @@ class IHaveNew:

def __new__(cls, **kwargs) -> Self: ...

# let type checker know these objects are sortable (by way of being a namedtuple)
def __lt__(self, other) -> bool: ...


def is_record(obj) -> bool:
"""Whether or not this object was produced by a record decorator."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2402,3 +2402,24 @@ def op1():
AssetsDefinition(
node_def=op1, keys_by_output_name={"out1": AssetKey("a"), "out2": AssetKey("a")}
)


def test_iterate_over_single_key():
key = AssetKey("ouch")
with pytest.raises(
DagsterInvariantViolationError,
match="You have attempted to iterate a single AssetKey object. "
"As of 1.9, this behavior is disallowed because it is likely unintentional and a bug.",
):
[_ for _ in key]


def test_index_in_to_key():
key = AssetKey("ouch")
with pytest.raises(
DagsterInvariantViolationError,
match="You have attempted to index directly in to the AssetKey object. "
"As of 1.9, this behavior is disallowed because it is likely unintentional and a bug. "
"Use asset_key.path instead to access the list of key components.",
):
key[0][0]
Original file line number Diff line number Diff line change
Expand Up @@ -592,13 +592,18 @@ def run(


def run_request(
asset_keys: Sequence[CoercibleToAssetKey],
asset_keys: Union[AssetKey, Sequence[CoercibleToAssetKey]],
partition_key: Optional[str] = None,
fail_keys: Optional[Sequence[str]] = None,
tags: Optional[Mapping[str, str]] = None,
) -> RunRequest:
if isinstance(asset_keys, AssetKey):
asset_selection = [asset_keys]
else:
asset_selection = [AssetKey.from_coercible(key) for key in asset_keys]

return RunRequest(
asset_selection=[AssetKey.from_coercible(key) for key in asset_keys],
asset_selection=asset_selection,
partition_key=partition_key,
tags={**(tags or {}), **({FAIL_TAG: json.dumps(fail_keys)} if fail_keys else {})},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def _download_artifact(self, context: InputContext):

artifact_name = parameters.get("name")
if artifact_name is None:
artifact_name = context.asset_key[0][0] # name of asset
artifact_name = context.asset_key.path[0] # name of asset

partitions = [
(key, f"{artifact_name}.{ str(key).replace('|', '-')}")
Expand Down

0 comments on commit e97c61b

Please sign in to comment.