Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions src/gt4py/next/otf/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from gt4py.eve import utils
from gt4py.next import common
from gt4py.next.iterator import ir as itir
from gt4py.next.otf import code_specs, definitions
from gt4py.next.otf.binding import interface

Expand All @@ -36,18 +35,13 @@ def compilation_hash(program_def: definitions.CompilableProgramDef) -> int:

def fingerprint_compilable_program(program_def: definitions.CompilableProgramDef) -> str:
"""
Generates a unique hash string for a stencil source program representing
the program, sorted offset_provider, and column_axis.
Generates a unique hash string for a compilable program representing
the program IR and all compile-time arguments.
"""
program: itir.Program = program_def.data
offset_provider: common.OffsetProvider = program_def.args.offset_provider
column_axis: Optional[common.Dimension] = program_def.args.column_axis

program_hash = utils.content_hash(
(
program.fingerprint(),
sorted(offset_provider.items(), key=lambda el: el[0]),
column_axis,
program_def.data.fingerprint(),
program_def.args,
)
)
Comment on lines 41 to 46

Expand Down
38 changes: 38 additions & 0 deletions tests/next_tests/unit_tests/iterator_tests/test_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import pytest

from gt4py.next.iterator import ir
from gt4py.next.otf import arguments, definitions, stages
from gt4py.next.type_system import type_specifications as ts
Comment on lines +12 to +13
from gt4py import eve


Expand Down Expand Up @@ -50,3 +52,39 @@ def node_maker(fun: str, filename: str):
node3 = node_maker("f3", "loc1")
assert node1.fingerprint() == node2.fingerprint()
assert node1.fingerprint() != node3.fingerprint()


def test_different_precisions():
program = ir.Program(
id="test_program",
function_definitions=[],
params=[ir.Sym(id="arg")],
declarations=[],
body=[],
)

compilable_single = definitions.CompilableProgramDef(
data=program,
args=arguments.CompileTimeArgs(
args=(ts.ScalarType(kind=ts.ScalarKind.FLOAT32),),
kwargs={},
offset_provider={},
column_axis=None,
argument_descriptor_contexts={},
),
)
compilable_double = definitions.CompilableProgramDef(
data=program,
args=arguments.CompileTimeArgs(
args=(ts.ScalarType(kind=ts.ScalarKind.FLOAT64),),
kwargs={},
offset_provider={},
column_axis=None,
argument_descriptor_contexts={},
),
)

hash_single = stages.fingerprint_compilable_program(compilable_single)
hash_double = stages.fingerprint_compilable_program(compilable_double)

assert hash_single != hash_double
Loading