From fa831dd73e854893e53d326f7463c3574ccffcc0 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 4 Nov 2025 12:38:55 -0600 Subject: [PATCH] EvaluationMapper: preserve Subscript --- pymbolic/mapper/evaluator.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/pymbolic/mapper/evaluator.py b/pymbolic/mapper/evaluator.py index c034ab15..a914a1d6 100644 --- a/pymbolic/mapper/evaluator.py +++ b/pymbolic/mapper/evaluator.py @@ -11,8 +11,6 @@ from __future__ import annotations -from pytools import ndindex - __copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner" @@ -41,6 +39,8 @@ from typing_extensions import override +from pytools import ndindex + from pymbolic.mapper import CachedMapper, CSECachingMapperMixin, Mapper, ResultT @@ -53,7 +53,7 @@ import pymbolic.primitives as p from pymbolic.geometric_algebra import MultiVector from pymbolic.rational import Rational - from pymbolic.typing import Expression + from pymbolic.typing import ArithmeticExpression, Expression class UnknownVariableError(Exception): @@ -116,7 +116,14 @@ def map_call_with_kwargs(self, expr: p.CallWithKwargs, /) -> ResultT: @override def map_subscript(self, expr: p.Subscript, /) -> ResultT: - return self.rec(expr.aggregate)[self.rec(expr.index)] + agg = self.rec(expr.aggregate) + index = self.rec(expr.index) + from pymbolic.primitives import EmptyOK, ExpressionNode + if isinstance(agg, ExpressionNode): + return cast("ResultT", + agg[EmptyOK(cast("ArithmeticExpression", index))]) + else: + return agg[index] # pyright: ignore[reportIndexIssue] @override def map_lookup(self, expr: p.Lookup, /) -> ResultT: