Skip to content

Commit 36691d9

Browse files
committed
polished cern caimira frontned model related methods
1 parent ebde0e1 commit 36691d9

File tree

5 files changed

+72
-56
lines changed

5 files changed

+72
-56
lines changed

caimira/src/caimira/api/controller/virus_report_controller.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@ def generate_model(form_obj, data_registry):
1818
return form_obj.build_model(sample_size=sample_size)
1919

2020

21-
def generate_report_results(form_obj, model):
21+
def generate_report_results(form_obj):
2222
return rg.calculate_report_data(
2323
form=form_obj,
24-
model=model,
2524
executor_factory=functools.partial(
2625
concurrent.futures.ThreadPoolExecutor, None, # TODO define report_parallelism
2726
),

caimira/src/caimira/calculator/report/virus_report_data.py

+16-27
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
import typing
66
import numpy as np
77
import matplotlib.pyplot as plt
8-
import urllib
9-
import zlib
108

119
from caimira.calculator.models import models, dataclass_utils, profiler, monte_carlo as mc
1210
from caimira.calculator.models.enums import ViralLoads
@@ -123,7 +121,9 @@ def _calculate_co2_concentration(CO2_model, time, fn_name=None):
123121

124122

125123
@profiler.profile
126-
def calculate_report_data(form: VirusFormData, model: models.ExposureModel, executor_factory: typing.Callable[[], concurrent.futures.Executor]) -> typing.Dict[str, typing.Any]:
124+
def calculate_report_data(form: VirusFormData, executor_factory: typing.Callable[[], concurrent.futures.Executor]) -> typing.Dict[str, typing.Any]:
125+
model: models.ExposureModel = form.build_model()
126+
127127
times = interesting_times(model)
128128
short_range_intervals = [interaction.presence.boundaries()[0]
129129
for interaction in model.short_range]
@@ -191,7 +191,7 @@ def calculate_report_data(form: VirusFormData, model: models.ExposureModel, exec
191191
uncertainties_plot(prob, conditional_probability_data)))
192192

193193
return {
194-
"model_repr": repr(model),
194+
"model": model,
195195
"times": list(times),
196196
"exposed_presence_intervals": exposed_presence_intervals,
197197
"short_range_intervals": short_range_intervals,
@@ -330,26 +330,7 @@ def img2base64(img_data) -> str:
330330
return f'data:image/png;base64,{pic_hash}'
331331

332332

333-
def generate_permalink(base_url, get_root_url, get_root_calculator_url, form: VirusFormData):
334-
form_dict = VirusFormData.to_dict(form, strip_defaults=True)
335-
336-
# Generate the calculator URL arguments that would be needed to re-create this
337-
# form.
338-
args = urllib.parse.urlencode(form_dict)
339-
340-
# Then zlib compress + base64 encode the string. To be inverted by the
341-
# /_c/ endpoint.
342-
compressed_args = base64.b64encode(zlib.compress(args.encode())).decode()
343-
qr_url = f"{base_url}{get_root_url()}/_c/{compressed_args}"
344-
url = f"{base_url}{get_root_calculator_url()}?{args}"
345-
346-
return {
347-
'link': url,
348-
'shortened': qr_url,
349-
}
350-
351-
352-
def manufacture_viral_load_scenarios_percentiles(model: mc.ExposureModel) -> typing.Dict[str, mc.ExposureModel]:
333+
def calculate_vl_scenarios_percentiles(model: mc.ExposureModel) -> typing.Dict[str, mc.ExposureModel]:
353334
viral_load = model.concentration_model.infected.virus.viral_load_in_sputum
354335
scenarios = {}
355336
for percentil in (0.01, 0.05, 0.25, 0.5, 0.75, 0.95, 0.99):
@@ -359,7 +340,9 @@ def manufacture_viral_load_scenarios_percentiles(model: mc.ExposureModel) -> typ
359340
)
360341
scenarios[str(vl)] = np.mean(
361342
specific_vl_scenario.infection_probability())
362-
return scenarios
343+
return {
344+
'alternative_viral_load': scenarios,
345+
}
363346

364347

365348
def manufacture_alternative_scenarios(form: VirusFormData) -> typing.Dict[str, mc.ExposureModel]:
@@ -451,7 +434,6 @@ def comparison_report(
451434
form: VirusFormData,
452435
report_data: typing.Dict[str, typing.Any],
453436
scenarios: typing.Dict[str, mc.ExposureModel],
454-
sample_times: typing.List[float],
455437
executor_factory: typing.Callable[[], concurrent.futures.Executor],
456438
):
457439
if (form.short_range_option == "short_range_no"):
@@ -474,7 +456,7 @@ def comparison_report(
474456
results = executor.map(
475457
scenario_statistics,
476458
scenarios.values(),
477-
[sample_times] * len(scenarios),
459+
[report_data['times']] * len(scenarios),
478460
[compute_prob_exposure] * len(scenarios),
479461
timeout=60,
480462
)
@@ -485,3 +467,10 @@ def comparison_report(
485467
return {
486468
'stats': statistics,
487469
}
470+
471+
472+
def alternative_scenarios_data(form: VirusFormData, report_data: typing.Dict[str, typing.Any], executor_factory: typing.Callable[[], concurrent.futures.Executor]) -> typing.Dict[str, typing.Any]:
473+
alternative_scenarios: typing.Dict[str, typing.Any] = manufacture_alternative_scenarios(form=form)
474+
return {
475+
'alternative_scenarios': comparison_report(form=form, report_data=report_data, scenarios=alternative_scenarios, executor_factory=executor_factory)
476+
}

cern_caimira/src/cern_caimira/apps/calculator/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,7 @@ async def post(self) -> None:
246246
max_workers=self.settings['handler_worker_pool_size'],
247247
timeout=300,
248248
)
249-
model = virus_report_controller.generate_model(form, data_registry)
250-
report_data_task = executor.submit(calculate_report_data, form, model,
249+
report_data_task = executor.submit(calculate_report_data, form,
251250
executor_factory=functools.partial(
252251
concurrent.futures.ThreadPoolExecutor,
253252
self.settings['report_generation_parallelism'],

cern_caimira/src/cern_caimira/apps/calculator/report/virus_report.py

+50-19
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
import json
66
import typing
77
import jinja2
8+
import urllib
9+
import zlib
10+
import base64
811
import numpy as np
912

1013
from .. import markdown_tools
1114

1215
from caimira.calculator.models import models
1316
from caimira.calculator.validators.virus.virus_validator import VirusFormData
14-
from caimira.calculator.report.virus_report_data import calculate_report_data, interesting_times, manufacture_alternative_scenarios, manufacture_viral_load_scenarios_percentiles, comparison_report, generate_permalink
17+
from caimira.calculator.report.virus_report_data import alternative_scenarios_data, calculate_report_data, calculate_vl_scenarios_percentiles
1518

1619

1720
def minutes_to_time(minutes: int) -> str:
@@ -62,6 +65,25 @@ def non_zero_percentage(percentage: int) -> str:
6265
return "{:0.1f}%".format(percentage)
6366

6467

68+
def generate_permalink(base_url, get_root_url, get_root_calculator_url, form: VirusFormData):
69+
form_dict = VirusFormData.to_dict(form, strip_defaults=True)
70+
71+
# Generate the calculator URL arguments that would be needed to re-create this
72+
# form.
73+
args = urllib.parse.urlencode(form_dict)
74+
75+
# Then zlib compress + base64 encode the string. To be inverted by the
76+
# /_c/ endpoint.
77+
compressed_args = base64.b64encode(zlib.compress(args.encode())).decode()
78+
qr_url = f"{base_url}{get_root_url()}/_c/{compressed_args}"
79+
url = f"{base_url}{get_root_calculator_url()}?{args}"
80+
81+
return {
82+
'link': url,
83+
'shortened': qr_url,
84+
}
85+
86+
6587
@dataclasses.dataclass
6688
class VirusReportGenerator:
6789
jinja_loader: jinja2.BaseLoader
@@ -74,44 +96,53 @@ def build_report(
7496
form: VirusFormData,
7597
executor_factory: typing.Callable[[], concurrent.futures.Executor],
7698
) -> str:
77-
model = form.build_model()
7899
context = self.prepare_context(
79-
base_url, model, form, executor_factory=executor_factory)
100+
base_url, form, executor_factory=executor_factory)
80101
return self.render(context)
81102

82103
def prepare_context(
83104
self,
84105
base_url: str,
85-
model: models.ExposureModel,
86106
form: VirusFormData,
87107
executor_factory: typing.Callable[[], concurrent.futures.Executor],
88108
) -> dict:
89109
now = datetime.utcnow().astimezone()
90110
time = now.strftime("%Y-%m-%d %H:%M:%S UTC")
91111

92-
data_registry_version = f"v{model.data_registry.version}" if model.data_registry.version else None
93112
context = {
94-
'model': model,
95113
'form': form,
96114
'creation_date': time,
97-
'data_registry_version': data_registry_version,
98115
}
99116

100-
scenario_sample_times = interesting_times(model)
101-
report_data = calculate_report_data(
102-
form, model, executor_factory=executor_factory)
117+
# Main report data
118+
report_data = calculate_report_data(form, executor_factory)
103119
context.update(report_data)
104120

105-
alternative_scenarios = manufacture_alternative_scenarios(form)
106-
context['alternative_viral_load'] = manufacture_viral_load_scenarios_percentiles(
107-
model) if form.conditional_probability_viral_loads else None
108-
context['alternative_scenarios'] = comparison_report(
109-
form, report_data, alternative_scenarios, scenario_sample_times, executor_factory=executor_factory,
110-
)
111-
context['permalink'] = generate_permalink(
121+
# Model and Data Registry
122+
model: models.ExposureModel = report_data['model']
123+
data_registry_version: typing.Optional[str] = f"v{model.data_registry.version}" if model.data_registry.version else None
124+
125+
# Alternative scenarios data
126+
alternative_scenarios: typing.Dict[str,typing.Any] = alternative_scenarios_data(form, report_data, executor_factory)
127+
context.update(alternative_scenarios)
128+
129+
# Alternative viral load data
130+
if form.conditional_probability_viral_loads:
131+
alternative_viral_load: typing.Dict[str,typing.Any] = calculate_vl_scenarios_percentiles(model)
132+
context.update(alternative_viral_load)
133+
134+
# Permalink
135+
permalink: typing.Dict[str, str] = generate_permalink(
112136
base_url, self.get_root_url, self.get_root_calculator_url, form)
113-
context['get_url'] = self.get_root_url
114-
context['get_calculator_url'] = self.get_root_calculator_url
137+
138+
# URLs (root, calculator and permalink)
139+
context.update({
140+
'model_repr': repr(model),
141+
'data_registry_version': data_registry_version,
142+
'permalink': permalink,
143+
'get_url': self.get_root_url,
144+
'get_calculator_url': self.get_root_calculator_url,
145+
})
115146

116147
return context
117148

cern_caimira/tests/test_report_generator.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -104,23 +104,21 @@ def test_interesting_times_w_temp(exposure_model_w_outside_temp_changes):
104104
np.testing.assert_allclose(result, expected)
105105

106106

107-
def test_expected_new_cases(baseline_form_with_sr: VirusFormData):
108-
model = baseline_form_with_sr.build_model()
109-
107+
def test_expected_new_cases(baseline_form_with_sr: VirusFormData):
110108
executor_factory = partial(
111109
concurrent.futures.ThreadPoolExecutor, 1,
112110
)
113111

114112
# Short- and Long-range contributions
115-
report_data = rep_gen.calculate_report_data(baseline_form_with_sr, model, executor_factory)
113+
report_data = rep_gen.calculate_report_data(baseline_form_with_sr, executor_factory)
116114
sr_lr_expected_new_cases = report_data['expected_new_cases']
117115
sr_lr_prob_inf = report_data['prob_inf']/100
118116

119117
# Long-range contributions alone
120-
scenario_sample_times = rep_gen.interesting_times(model)
118+
scenario_sample_times = report_data['times']
121119
alternative_scenarios = rep_gen.manufacture_alternative_scenarios(baseline_form_with_sr)
122120
alternative_statistics = rep_gen.comparison_report(
123-
baseline_form_with_sr, report_data, alternative_scenarios, scenario_sample_times, executor_factory=executor_factory,
121+
baseline_form_with_sr, report_data, alternative_scenarios, executor_factory=executor_factory,
124122
)
125123

126124
lr_expected_new_cases = alternative_statistics['stats']['Base scenario without short-range interactions']['expected_new_cases']

0 commit comments

Comments
 (0)