Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 33 additions & 8 deletions qupulse/program/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
from qupulse import ChannelID
from qupulse.utils.types import SingletonABCMeta, frozendict, DocStringABCMeta
from qupulse.expressions import ExpressionScalar
from qupulse.program.values import DynamicLinearValue


_TrafoValue = Union[Real, ExpressionScalar]
_TrafoValue = Union[Real, ExpressionScalar, DynamicLinearValue]


__all__ = ['Transformation', 'IdentityTransformation', 'LinearTransformation', 'ScalingTransformation',
Expand Down Expand Up @@ -63,7 +64,11 @@ def is_constant_invariant(self):

def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]:
return frozenset()


@property
def contains_sweepval(self) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be a great churn to rename this to contains_dynamic_value and make it a method instead of a property?

Copy link
Collaborator Author

@Nomos11 Nomos11 Aug 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be no problem.

what is the benefit here of a method over property?

raise NotImplementedError()


class IdentityTransformation(Transformation, metaclass=SingletonABCMeta):
__slots__ = ()
Expand Down Expand Up @@ -275,7 +280,7 @@ def __init__(self, offsets: Mapping[ChannelID, _TrafoValue]):

def __call__(self, time: Union[np.ndarray, float],
data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]:
offsets = _instantiate_expression_dict(time, self._offsets)
offsets = _instantiate_expression_dict(time, self._offsets, default_sweepval=0.)
return {channel: channel_values + offsets[channel] if channel in offsets else channel_values
for channel, channel_values in data.items()}

Expand Down Expand Up @@ -308,7 +313,11 @@ def is_constant_invariant(self):

def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]:
return _get_constant_output_channels(self._offsets, input_channels)


@property
def contains_sweepval(self) -> bool:
return any(isinstance(o,DynamicLinearValue) for o in self._offsets.values())


class ScalingTransformation(Transformation):
__slots__ = ('_factors',)
Expand All @@ -319,7 +328,7 @@ def __init__(self, factors: Mapping[ChannelID, _TrafoValue]):

def __call__(self, time: Union[np.ndarray, float],
data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]:
factors = _instantiate_expression_dict(time, self._factors)
factors = _instantiate_expression_dict(time, self._factors, default_sweepval=1.)
return {channel: channel_values * factors[channel] if channel in factors else channel_values
for channel, channel_values in data.items()}

Expand Down Expand Up @@ -352,7 +361,11 @@ def is_constant_invariant(self):

def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]:
return _get_constant_output_channels(self._factors, input_channels)


@property
def contains_sweepval(self) -> bool:
return any(isinstance(o,DynamicLinearValue) for o in self._factors.values())


try:
if TYPE_CHECKING:
Expand Down Expand Up @@ -437,7 +450,11 @@ def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -
output_channels.add(ch)

return output_channels


@property
def contains_sweepval(self) -> bool:
return any(isinstance(o,DynamicLinearValue) for o in self._channels.values())


def chain_transformations(*transformations: Transformation) -> Transformation:
parsed_transformations = []
Expand All @@ -456,12 +473,20 @@ def chain_transformations(*transformations: Transformation) -> Transformation:
return ChainedTransformation(*parsed_transformations)


def _instantiate_expression_dict(time, expressions: Mapping[str, _TrafoValue]) -> Mapping[str, Union[Real, np.ndarray]]:
def _instantiate_expression_dict(time,
expressions: Mapping[str, _TrafoValue],
default_sweepval: Real,
) -> Mapping[str, Union[Real, np.ndarray]]:
scope = {'t': time}
modified_expressions = {}
for name, value in expressions.items():
if hasattr(value, 'evaluate_in_scope'):
modified_expressions[name] = value.evaluate_in_scope(scope)
if isinstance(value, DynamicLinearValue):
# it is assumed that swept parameters will be handled by the ProgramBuilder accordingly
# such that here only an "identity" trafo is to be applied and the
# trafos are set in the program internally.
modified_expressions[name] = default_sweepval
if modified_expressions:
return {**expressions, **modified_expressions}
else:
Expand Down
27 changes: 26 additions & 1 deletion tests/_program/transformation_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from qupulse.program.transformation import LinearTransformation, Transformation, IdentityTransformation,\
ChainedTransformation, ParallelChannelTransformation, chain_transformations, OffsetTransformation,\
ScalingTransformation
from qupulse.program.values import DynamicLinearValue


class TransformationStub(Transformation):
Expand Down Expand Up @@ -179,7 +180,11 @@ class IdentityTransformationTests(unittest.TestCase):
def test_compare_key(self):
with self.assertWarns(DeprecationWarning):
self.assertIsNone(IdentityTransformation().compare_key)


def test_sweepval(self):
with self.assertRaises(NotImplementedError):
IdentityTransformation().contains_sweepval

def test_singleton(self):
self.assertIs(IdentityTransformation(), IdentityTransformation())

Expand Down Expand Up @@ -489,6 +494,17 @@ def test_time_dependence(self):
}, transformed)


def test_sweepval(self):
channels = {'X': 2, 'Y': DynamicLinearValue(0.1, {'a':0.02})}
trafo = OffsetTransformation(channels)
self.assertEqual(trafo.contains_sweepval, True)

channels = {'X': 2, 'Y': 2}
trafo = OffsetTransformation(channels)
self.assertEqual(trafo.contains_sweepval, False)



class TestScalingTransformation(unittest.TestCase):
def setUp(self) -> None:
self.constant_scales = {'A': 1.5, 'B': 1.2}
Expand Down Expand Up @@ -561,3 +577,12 @@ def test_time_dependence(self):
'Z': np.tan(t) * np.exp(t),
'K': values['K']
}, transformed)

def test_sweepval(self):
channels = {'X': 2, 'Y': DynamicLinearValue(0.1, {'a':0.02})}
trafo = ScalingTransformation(channels)
self.assertEqual(trafo.contains_sweepval, True)

channels = {'X': 2, 'Y': 2}
trafo = ScalingTransformation(channels)
self.assertEqual(trafo.contains_sweepval, False)
9 changes: 5 additions & 4 deletions tests/pulses/arithmetic_pulse_template_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,11 +435,11 @@ def test_internal_create_program(self):
to_single_waveform = {'something_else'}
program_builder = mock.Mock()

expected_transformation = mock.create_autospec(IdentityTransformation,instance=True)
with self.assertWarns(DeprecationWarning):
expected_transformation = mock.Mock(spec=IdentityTransformation())
IdentityTransformation().compare_key

with self.assertWarns(DeprecationWarning):
inner_trafo = mock.Mock(spec=IdentityTransformation())
inner_trafo = mock.create_autospec(spec=IdentityTransformation,instance=True)
inner_trafo.chain.return_value = expected_transformation

with mock.patch.object(rhs, '_create_program') as inner_create_program:
Expand Down Expand Up @@ -595,10 +595,11 @@ def test_build_waveform(self):
channel_mapping = dict(a='u', b='v')

inner_wf = DummyWaveform(duration=6, defined_channels={'a'})
trafo = mock.create_autospec(IdentityTransformation,instance=True)
with self.assertWarns(DeprecationWarning):
# mock will inspect alsod eprecated attributes
# TODO: remove assert as soon as attribute is removed
trafo = mock.Mock(spec=IdentityTransformation())
IdentityTransformation().compare_key

arith = ArithmeticPulseTemplate(pt, '-', 6)

Expand Down
Loading