Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions pyter/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ class HalfLifeData(AbstractData):
log_base: np.ndarray = attrs.Factory(lambda: np.array([10]))
well_volume: np.ndarray = attrs.Factory(lambda: np.array([1.0]))
false_hit_rate: np.ndarray = attrs.Factory(lambda: np.array([0]))
log_titer_change_other: np.ndarray = attrs.Factory(lambda: np.array([0]))
log_well_change_other: np.ndarray = attrs.Factory(lambda: np.array([0]))

well_internal_id_values: dict = attrs.Factory(dict)
titer_internal_id_values: dict = attrs.Factory(dict)
Expand All @@ -302,7 +302,8 @@ class HalfLifeData(AbstractData):
id_representative_rows: dict = attrs.Factory(dict)
n_values: dict = attrs.Factory(dict)

titer_time: np.ndarray = None
titer_time: np.ndarray | None = None
log_titer_change_other: np.ndarray | None = None

def update_internal_ids(self):
"""
Expand Down Expand Up @@ -360,6 +361,10 @@ def update_internal_ids(self):
)

self.titer_time = self.well_time[self.id_representative_rows["titer"]]
self.log_titer_change_other = np.broadcast_to(
self.log_well_change_other,
self.well_internal_id_values["titer"].shape,
)[self.id_representative_rows["titer"]]

def index_prior_parameters(self):
"""
Expand Down
21 changes: 11 additions & 10 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@
)
@pytest.mark.parametrize(
"known_change_vec",
[np.array([-1.5, 15.5, 2.6431]), np.array([-0.352]), 0, None],
[np.array([-1.5, 1.5, 2.6431, -3.322]), np.array([-0.352]), 0, None],
)
def test_non_inactivation_change(data, known_change_vec):
"""
Test that the log_titer_change_other argument
Test that the log_well_change_other argument
works as expected for modeling known change in
titers due to factors other than inactivation.
"""
Expand All @@ -57,19 +57,20 @@ def test_non_inactivation_change(data, known_change_vec):
with seed(rng_seed=5):
if known_change_vec is not None:
data_change = copy.deepcopy(data)
data_change.log_titer_change_other = known_change_vec
data_change.log_well_change_other = known_change_vec
else:
data_change = data
sim_titers_change, sim_wells_change = model.model(
data=data_change.freeze()
)
assert all((sim_wells == 0) | (sim_wells == 1))
assert all((sim_wells_change == 0) | (sim_wells_change == 1))
if known_change_vec is not None:
assert all(sim_titers + known_change_vec == sim_titers_change)
assert not any(
sim_titers + known_change_vec + 0.52 == sim_titers_change
)
if known_change_vec is None:
expected_change = 0
else:
assert all(sim_titers == sim_titers_change)
assert not any(sim_titers + 0.52 == sim_titers_change)
expected_change = np.broadcast_to(
known_change_vec,
data_change.well_internal_id_values["titer"].shape,
)[data_change.id_representative_rows["titer"]]
assert all(sim_titers + expected_change == sim_titers_change)
assert not any(sim_titers + expected_change + 0.52 == sim_titers_change)