diff --git a/qupulse/utils/__init__.py b/qupulse/utils/__init__.py index 8655cd15..fa571537 100644 --- a/qupulse/utils/__init__.py +++ b/qupulse/utils/__init__.py @@ -13,6 +13,7 @@ from qupulse.expressions import ExpressionScalar, ExpressionLike import numpy +import sympy as sp try: from math import isclose @@ -130,7 +131,7 @@ def forced_hash(obj) -> int: def to_next_multiple(sample_rate: ExpressionLike, quantum: int, - min_quanta: Optional[int] = None) -> Callable[[ExpressionLike],ExpressionScalar]: + min_quanta: Optional[int] = None) -> Callable[[ExpressionScalar], ExpressionLike]: """Construct a helper function to expand a duration to one corresponding to valid sample multiples according to the arguments given. Useful e.g. for PulseTemplate.pad_to's 'to_new_duration'-argument. @@ -140,17 +141,34 @@ def to_next_multiple(sample_rate: ExpressionLike, quantum: int, quantum: number of samples to whose next integer multiple the duration shall be rounded up to. min_quanta: number of multiples of quantum not to fall short of. Returns: - A function that takes a duration (ExpressionLike) as input, and returns + A function that takes a duration as input, and returns a duration rounded up to the next valid samples count in given sample rate. The function returns 0 if duration==0, <0 is not checked if min_quanta is None. """ sample_rate = ExpressionScalar(sample_rate) + #is it more efficient to omit the Max call if not necessary? if min_quanta is None: #double negative for ceil division. return lambda duration: -(-(duration*sample_rate)//quantum) * (quantum/sample_rate) else: - #still return 0 if duration==0 - return lambda duration: ExpressionScalar(f'{quantum}/({sample_rate})*Max({min_quanta},-(-({duration})*{sample_rate}//{quantum}))*Max(0, sign({duration}))') - \ No newline at end of file + # work with sympy + sample_rate = sample_rate.sympified_expression + duration_per_quantum = sp.Integer(quantum) / sample_rate + minimal_duration = duration_per_quantum * min_quanta + + def build_next_multiple(duration: ExpressionScalar) -> ExpressionLike: + duration = sp.sympify(duration) + n_quanta = sp.ceiling(duration / duration_per_quantum) + rounded_up_duration = n_quanta * duration_per_quantum + + next_multiple_sp = sp.Piecewise( + (0, sp.Le(n_quanta, 0)), + (minimal_duration, sp.Le(n_quanta, min_quanta)), + (rounded_up_duration, True), + evaluate=False, + ) + return ExpressionScalar(next_multiple_sp) + + return build_next_multiple diff --git a/tests/utils/utils_tests.py b/tests/utils/utils_tests.py index 3675109a..a6d880e8 100644 --- a/tests/utils/utils_tests.py +++ b/tests/utils/utils_tests.py @@ -120,12 +120,22 @@ def test_to_next_multiple(self): self.assertEqual(evaluated, expected) duration = 6185240.0000001 - evaluated = to_next_multiple(sample_rate=1.0,quantum=16,min_quanta=13)(duration) + evaluated = to_next_multiple(sample_rate=1.0,quantum=16,min_quanta=13)(duration).evaluate_numeric() expected = 6185248 self.assertEqual(evaluated, expected) + duration = 63.99 + evaluated = to_next_multiple(sample_rate=1.0,quantum=16,min_quanta=4)(duration).evaluate_numeric() + expected = 64 + self.assertEqual(evaluated, expected) + + duration = 64.01 + evaluated = to_next_multiple(sample_rate=1.0,quantum=16,min_quanta=4)(duration).evaluate_numeric() + expected = 80 + self.assertEqual(evaluated, expected) + duration = 0. - evaluated = to_next_multiple(sample_rate=1.0,quantum=16,min_quanta=13)(duration) + evaluated = to_next_multiple(sample_rate=1.0,quantum=16,min_quanta=13)(duration).evaluate_numeric() expected = 0. self.assertEqual(evaluated, expected) @@ -139,12 +149,18 @@ def test_to_next_multiple(self): dict(q=3.14159,w=1.0)) expected = 16. self.assertEqual(evaluated, expected) - - #bracket silent bug - duration = ExpressionScalar('51 + q*51') - evaluated = to_next_multiple(sample_rate=1.0,quantum=16,min_quanta=1)(duration).evaluate_in_scope( - dict(q=3.14159,)) - expected = 224. - self.assertEqual(evaluated, expected) - - \ No newline at end of file + + +def test_to_next_multiple_padding_duration_evaluation(benchmark): + # reminder how to manually run pytest tests: + # use pytest -k test_to_next_multiple_padding_duration_evaluation + # or for faster collection phase + # pytest -k test_to_next_multiple_padding_duration_evaluation tests/utils/utils_tests.py + + from qupulse.pulses import FunctionPT + pt = FunctionPT('start+t/t_gate*(end-start)', 't_gate', 'a') + + def padding(): + pt.pad_to(to_next_multiple(2.4, 16, 4)).duration.evaluate_in_scope({'t_gate': 10.}) + + benchmark(padding)