diff --git a/pyter/data.py b/pyter/data.py index 179931b..c4e20ea 100644 --- a/pyter/data.py +++ b/pyter/data.py @@ -290,7 +290,7 @@ class HalfLifeData(AbstractData): log_base: np.ndarray = attrs.Factory(lambda: np.array([10])) well_volume: np.ndarray = attrs.Factory(lambda: np.array([1.0])) false_hit_rate: np.ndarray = attrs.Factory(lambda: np.array([0])) - log_titer_change_other: np.ndarray = attrs.Factory(lambda: np.array([0])) + log_well_change_other: np.ndarray = attrs.Factory(lambda: np.array([0])) well_internal_id_values: dict = attrs.Factory(dict) titer_internal_id_values: dict = attrs.Factory(dict) @@ -302,7 +302,8 @@ class HalfLifeData(AbstractData): id_representative_rows: dict = attrs.Factory(dict) n_values: dict = attrs.Factory(dict) - titer_time: np.ndarray = None + titer_time: np.ndarray | None = None + log_titer_change_other: np.ndarray | None = None def update_internal_ids(self): """ @@ -360,6 +361,10 @@ def update_internal_ids(self): ) self.titer_time = self.well_time[self.id_representative_rows["titer"]] + self.log_titer_change_other = np.broadcast_to( + self.log_well_change_other, + self.well_internal_id_values["titer"].shape, + )[self.id_representative_rows["titer"]] def index_prior_parameters(self): """ diff --git a/test/test_models.py b/test/test_models.py index 44bbde1..a5c5b1f 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -40,11 +40,11 @@ ) @pytest.mark.parametrize( "known_change_vec", - [np.array([-1.5, 15.5, 2.6431]), np.array([-0.352]), 0, None], + [np.array([-1.5, 1.5, 2.6431, -3.322]), np.array([-0.352]), 0, None], ) def test_non_inactivation_change(data, known_change_vec): """ - Test that the log_titer_change_other argument + Test that the log_well_change_other argument works as expected for modeling known change in titers due to factors other than inactivation. """ @@ -57,7 +57,7 @@ def test_non_inactivation_change(data, known_change_vec): with seed(rng_seed=5): if known_change_vec is not None: data_change = copy.deepcopy(data) - data_change.log_titer_change_other = known_change_vec + data_change.log_well_change_other = known_change_vec else: data_change = data sim_titers_change, sim_wells_change = model.model( @@ -65,11 +65,12 @@ def test_non_inactivation_change(data, known_change_vec): ) assert all((sim_wells == 0) | (sim_wells == 1)) assert all((sim_wells_change == 0) | (sim_wells_change == 1)) - if known_change_vec is not None: - assert all(sim_titers + known_change_vec == sim_titers_change) - assert not any( - sim_titers + known_change_vec + 0.52 == sim_titers_change - ) + if known_change_vec is None: + expected_change = 0 else: - assert all(sim_titers == sim_titers_change) - assert not any(sim_titers + 0.52 == sim_titers_change) + expected_change = np.broadcast_to( + known_change_vec, + data_change.well_internal_id_values["titer"].shape, + )[data_change.id_representative_rows["titer"]] + assert all(sim_titers + expected_change == sim_titers_change) + assert not any(sim_titers + expected_change + 0.52 == sim_titers_change)