diff --git a/qupulse/program/transformation.py b/qupulse/program/transformation.py index 6350cf65a..607c8dc42 100644 --- a/qupulse/program/transformation.py +++ b/qupulse/program/transformation.py @@ -14,9 +14,10 @@ from qupulse import ChannelID from qupulse.utils.types import SingletonABCMeta, frozendict, DocStringABCMeta from qupulse.expressions import ExpressionScalar +from qupulse.program.values import DynamicLinearValue -_TrafoValue = Union[Real, ExpressionScalar] +_TrafoValue = Union[Real, ExpressionScalar, DynamicLinearValue] __all__ = ['Transformation', 'IdentityTransformation', 'LinearTransformation', 'ScalingTransformation', @@ -63,7 +64,10 @@ def is_constant_invariant(self): def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: return frozenset() - + + def contains_dynamic_value(self) -> bool: + raise NotImplementedError() + class IdentityTransformation(Transformation, metaclass=SingletonABCMeta): __slots__ = () @@ -275,7 +279,7 @@ def __init__(self, offsets: Mapping[ChannelID, _TrafoValue]): def __call__(self, time: Union[np.ndarray, float], data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]: - offsets = _instantiate_expression_dict(time, self._offsets) + offsets = _instantiate_expression_dict(time, self._offsets, default_dynamic_linear_value=0.0) return {channel: channel_values + offsets[channel] if channel in offsets else channel_values for channel, channel_values in data.items()} @@ -308,7 +312,10 @@ def is_constant_invariant(self): def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: return _get_constant_output_channels(self._offsets, input_channels) - + + def contains_dynamic_value(self) -> bool: + return any(isinstance(o,DynamicLinearValue) for o in self._offsets.values()) + class ScalingTransformation(Transformation): __slots__ = ('_factors',) @@ -319,7 +326,7 @@ def __init__(self, factors: Mapping[ChannelID, _TrafoValue]): def __call__(self, time: Union[np.ndarray, float], data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]: - factors = _instantiate_expression_dict(time, self._factors) + factors = _instantiate_expression_dict(time, self._factors, default_dynamic_linear_value=1.0) return {channel: channel_values * factors[channel] if channel in factors else channel_values for channel, channel_values in data.items()} @@ -352,7 +359,10 @@ def is_constant_invariant(self): def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: return _get_constant_output_channels(self._factors, input_channels) - + + def contains_dynamic_value(self) -> bool: + return any(isinstance(o,DynamicLinearValue) for o in self._factors.values()) + try: if TYPE_CHECKING: @@ -437,7 +447,10 @@ def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) - output_channels.add(ch) return output_channels - + + def contains_dynamic_value(self) -> bool: + return any(isinstance(o,DynamicLinearValue) for o in self._channels.values()) + def chain_transformations(*transformations: Transformation) -> Transformation: parsed_transformations = [] @@ -456,12 +469,20 @@ def chain_transformations(*transformations: Transformation) -> Transformation: return ChainedTransformation(*parsed_transformations) -def _instantiate_expression_dict(time, expressions: Mapping[str, _TrafoValue]) -> Mapping[str, Union[Real, np.ndarray]]: +def _instantiate_expression_dict(time, + expressions: Mapping[str, _TrafoValue], + default_dynamic_linear_value: Real, + ) -> Mapping[str, Union[Real, np.ndarray]]: scope = {'t': time} modified_expressions = {} for name, value in expressions.items(): if hasattr(value, 'evaluate_in_scope'): modified_expressions[name] = value.evaluate_in_scope(scope) + if isinstance(value, DynamicLinearValue): + # it is assumed that swept parameters will be handled by the ProgramBuilder accordingly + # such that here only an "identity" trafo is to be applied and the + # trafos are set in the program internally. + modified_expressions[name] = default_dynamic_linear_value if modified_expressions: return {**expressions, **modified_expressions} else: diff --git a/tests/_program/transformation_tests.py b/tests/_program/transformation_tests.py index 2f00f9c27..23e0df90c 100644 --- a/tests/_program/transformation_tests.py +++ b/tests/_program/transformation_tests.py @@ -8,6 +8,7 @@ from qupulse.program.transformation import LinearTransformation, Transformation, IdentityTransformation,\ ChainedTransformation, ParallelChannelTransformation, chain_transformations, OffsetTransformation,\ ScalingTransformation +from qupulse.program.values import DynamicLinearValue class TransformationStub(Transformation): @@ -179,7 +180,11 @@ class IdentityTransformationTests(unittest.TestCase): def test_compare_key(self): with self.assertWarns(DeprecationWarning): self.assertIsNone(IdentityTransformation().compare_key) - + + def test_sweepval(self): + with self.assertRaises(NotImplementedError): + IdentityTransformation().contains_dynamic_value() + def test_singleton(self): self.assertIs(IdentityTransformation(), IdentityTransformation()) @@ -489,6 +494,17 @@ def test_time_dependence(self): }, transformed) + def test_sweepval(self): + channels = {'X': 2, 'Y': DynamicLinearValue(0.1, {'a':0.02})} + trafo = OffsetTransformation(channels) + self.assertEqual(trafo.contains_dynamic_value(), True) + + channels = {'X': 2, 'Y': 2} + trafo = OffsetTransformation(channels) + self.assertEqual(trafo.contains_dynamic_value(), False) + + + class TestScalingTransformation(unittest.TestCase): def setUp(self) -> None: self.constant_scales = {'A': 1.5, 'B': 1.2} @@ -561,3 +577,12 @@ def test_time_dependence(self): 'Z': np.tan(t) * np.exp(t), 'K': values['K'] }, transformed) + + def test_sweepval(self): + channels = {'X': 2, 'Y': DynamicLinearValue(0.1, {'a':0.02})} + trafo = ScalingTransformation(channels) + self.assertEqual(trafo.contains_dynamic_value(), True) + + channels = {'X': 2, 'Y': 2} + trafo = ScalingTransformation(channels) + self.assertEqual(trafo.contains_dynamic_value(), False) diff --git a/tests/pulses/arithmetic_pulse_template_tests.py b/tests/pulses/arithmetic_pulse_template_tests.py index 7f833d7ad..919fdc7d6 100644 --- a/tests/pulses/arithmetic_pulse_template_tests.py +++ b/tests/pulses/arithmetic_pulse_template_tests.py @@ -435,11 +435,11 @@ def test_internal_create_program(self): to_single_waveform = {'something_else'} program_builder = mock.Mock() + expected_transformation = mock.create_autospec(IdentityTransformation,instance=True) with self.assertWarns(DeprecationWarning): - expected_transformation = mock.Mock(spec=IdentityTransformation()) + IdentityTransformation().compare_key - with self.assertWarns(DeprecationWarning): - inner_trafo = mock.Mock(spec=IdentityTransformation()) + inner_trafo = mock.create_autospec(spec=IdentityTransformation,instance=True) inner_trafo.chain.return_value = expected_transformation with mock.patch.object(rhs, '_create_program') as inner_create_program: @@ -595,10 +595,11 @@ def test_build_waveform(self): channel_mapping = dict(a='u', b='v') inner_wf = DummyWaveform(duration=6, defined_channels={'a'}) + trafo = mock.create_autospec(IdentityTransformation,instance=True) with self.assertWarns(DeprecationWarning): # mock will inspect alsod eprecated attributes # TODO: remove assert as soon as attribute is removed - trafo = mock.Mock(spec=IdentityTransformation()) + IdentityTransformation().compare_key arith = ArithmeticPulseTemplate(pt, '-', 6)