Merged
Changes from 2 commits
99 changes: 61 additions & 38 deletions person_story.py
@@ -1,7 +1,7 @@
"""Story generators for the CC HIC OMOP schema."""
import datetime as dt
from typing import Callable, Generator, Optional, Union, cast

from sqlsynthgen.utils import generate_time_series
import numpy as np
from mimesis import Generic
import random
@@ -155,65 +155,84 @@ def gen_blood_pressure_events( # pylint: disable=too-many-arguments
tables (measurements, observation, etc.).
"""

def populate_blood_pressure_values(

def generate_paired_measurement(
Collaborator Author:
@myyong, I've done some abstraction in this function. I didn't go too far, but it may be one step toward #184.

person_id: int,
visit_occurrence_id: int,
event_datetime: dt.datetime,
values: tuple[float, float],
measurement_concept_id: tuple[int,int],
measurement_type_concept_ids: int,
unit_concept_id: int,
unit_source_value: str,
) -> tuple[SqlRow, SqlRow]:

Systolic_blood_pressure_by_Noninvasive = 21492239
Diastolic_blood_pressure_by_Noninvasive = 21492240
measurement_type_concept_id = 32817 # EHR measurement
avg_systolic = 114.236842
avg_diastolic = 74.447368
avg_difference = avg_systolic - avg_diastolic
unit_concept_id = 8876 # mmHg

gender = cast(int, person["gender_concept_id"])
if gender == 8507:
systolic_value = random_normal(src_stats["bp_profile"][0]["average_under_60_systolic"],src_stats["bp_profile"][0]["stddev_under_60_systolic"])
diastolic_value = src_stats["bp_profile"][0]["average_systolic_diastolic_difference"] + systolic_value
elif gender == 8532:
systolic_value = random_normal(src_stats["bp_profile"][1]["average_under_60_systolic"],src_stats["bp_profile"][1]["stddev_under_60_systolic"])
diastolic_value = src_stats["bp_profile"][1]["average_systolic_diastolic_difference"] + systolic_value
else:
systolic_value = avg_systolic
diastolic_value = avg_diastolic


### This can be abstracted to generate any number of sets of measurements
"""Generate two rows for the measurement table."""
systolic: SqlRow = {
"measurement_concept_id": cast(int, Systolic_blood_pressure_by_Noninvasive),
measurement1: SqlRow = {
"measurement_concept_id": cast(int, measurement_concept_id[0]),
"person_id": person_id,
"visit_occurrence_id": visit_occurrence_id,
"measurement_datetime": event_datetime,
"measurement_date": event_datetime.date(),
"measurement_type_concept_id": measurement_type_concept_id,
"measurement_type_concept_id": measurement_type_concept_ids,
"unit_concept_id": unit_concept_id,
"unit_source_value": "mmHg",
"value_as_number": systolic_value,
"unit_source_value": unit_source_value,
"value_as_number": values[0],
}

diastolic: SqlRow = {
"measurement_concept_id": cast(int, Diastolic_blood_pressure_by_Noninvasive),
measurement2: SqlRow = {
"measurement_concept_id": cast(int, measurement_concept_id[1]),
"person_id": person_id,
"visit_occurrence_id": visit_occurrence_id,
"measurement_datetime": event_datetime,
"measurement_date": event_datetime.date(),
"measurement_type_concept_id": measurement_type_concept_id,
"measurement_type_concept_id": measurement_type_concept_ids,
"unit_concept_id": unit_concept_id,
"unit_source_value": "mmHg",
"value_as_number": diastolic_value,
"unit_source_value": unit_source_value,
"value_as_number": values[1],
}
return systolic, diastolic
return measurement1, measurement2
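For reference, a minimal sketch of a single call to this helper (it is defined inside gen_blood_pressure_events, so this is illustrative rather than standalone code). The person ID, visit ID, and datetime below are placeholders; the concept IDs and units match the ones the loop further down passes:

systolic_row, diastolic_row = generate_paired_measurement(
    person_id=1,                                    # placeholder
    visit_occurrence_id=10,                         # placeholder
    event_datetime=dt.datetime(2024, 1, 1, 8, 30),  # placeholder
    values=(120.0, 80.0),                           # (systolic, diastolic) in mmHg
    measurement_concept_id=(21492239, 21492240),    # systolic / diastolic by non-invasive
    measurement_type_concept_ids=32817,             # EHR measurement
    unit_concept_id=8876,                           # mmHg
    unit_source_value="mmHg",
)
# Each returned SqlRow is ready to be appended as a ("measurement", row) event.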

event_datetimes = random_event_times(avg_rate, visit_occurrence)

avg_systolic = 114.236842
avg_diastolic = 74.447368
sys_bp_non_invasive_concept_id = 21492239
dias_bp_non_invasive_concept_id = 21492240
measurement_type_concept_id = 32817 # EHR measurement
unit_source_value = "mmHg"
unit_concept_id = 8876 # mmHg

gender = cast(int, person["gender_concept_id"])
if gender == 8507:
systolic_value = np.round(generate_time_series(len(event_datetimes), 'iid',
{'mean': src_stats["bp_profile"][0]["average_under_60_systolic"],
'std': src_stats["bp_profile"][0]["stddev_under_60_systolic"]},
random_state=42))
diastolic_value = np.round(random_normal(src_stats["bp_profile"][0]["average_systolic_diastolic_difference"], src_stats["bp_profile"][0]["average_systolic_diastolic_difference"] * 0.1) + systolic_value)
elif gender == 8532:
systolic_value = np.round(generate_time_series(len(event_datetimes), 'iid',
{'mean': src_stats["bp_profile"][1]["average_under_60_systolic"],
'std': src_stats["bp_profile"][1]["stddev_under_60_systolic"]},
random_state=42))
diastolic_value = np.round(random_normal(src_stats["bp_profile"][1]["average_systolic_diastolic_difference"], src_stats["bp_profile"][1]["average_systolic_diastolic_difference"] * 0.1) + systolic_value)
else:
# Fall back to the population averages; use arrays so the per-event indexing below still works.
systolic_value = np.full(len(event_datetimes), avg_systolic)
diastolic_value = np.full(len(event_datetimes), avg_diastolic)

events: list[tuple[str, SqlRow]] = []
for event_datetime in sorted(event_datetimes):
systolic, diastolic = populate_blood_pressure_values(cast(int, person["person_id"]),
for index, event_datetime in enumerate(sorted(event_datetimes)):
systolic_dict, diastolic_dict = generate_paired_measurement(cast(int, person["person_id"]),
cast(int, visit_occurrence["visit_occurrence_id"]),
event_datetime)
events.append(("measurement", systolic))
events.append(("measurement", diastolic))
event_datetime, (systolic_value[index], diastolic_value[index]),
(sys_bp_non_invasive_concept_id, dias_bp_non_invasive_concept_id),
measurement_type_concept_id, unit_concept_id, unit_source_value)
events.append(("measurement", systolic_dict))
events.append(("measurement", diastolic_dict))
return events

def generate(
@@ -240,11 +259,15 @@ def generate(
death_row = (yield death) if death else None
visit_occurrence = yield gen_visit_occurrence(person, death_row, src_stats)

# abs to avoid negative rates due to random normal variation
avg_rate = abs(random_normal(
src_stats["avg_measurements_per_visit_hour"][0]['avg_measurements_per_hour'],
src_stats["avg_measurements_per_visit_hour"][0]['stddev_measurements_per_hour'] ))
src_stats["avg_measurements_per_visit_hour"][0]['stddev_measurements_per_hour'] )
)


print(f"Generating blood pressure events at an average rate of {avg_rate} per hour.")
for event in gen_blood_pressure_events(
avg_rate,
visit_occurrence,
207 changes: 205 additions & 2 deletions sqlsynthgen/utils.py
@@ -6,8 +6,8 @@
from importlib import import_module
from pathlib import Path
from types import ModuleType
from typing import Any, Final, Mapping, Optional, Union

from typing import Any, Final, Mapping, Optional, Union, Literal, Dict
import numpy as np
import yaml
from jsonschema.exceptions import ValidationError
from jsonschema.validators import validate
@@ -179,3 +179,206 @@ def conf_logger(verbose: bool) -> None:

logger.addHandler(stdout_handler)
logger.addHandler(stderr_handler)


def generate_time_series(
N: int,
model_option: Literal["iid", "random_walk", "ar1"],
model_params: Dict[str, Any],
random_state: int = 42
) -> np.ndarray:
"""
Generate a synthetic time series using one of three simple models.

Parameters
----------
N : int
Number of time steps.
model_option : {"iid", "random_walk", "ar1"}
Which model to use.
model_params : dict
Dictionary of parameters. Expected keys:

For all models:
- "mean": float
- "std": float

For random_walk:
- "drift": float
- "epsilon_std": float

For ar1:
- "mu": float
- "phi": float
- "epsilon_std": float

random_state : int
Random seed used to initialise the generator (default 42).

Returns
-------
np.ndarray
Synthetic time series of length N.
"""

rng = np.random.default_rng(random_state)

# Initialise x0 from marginal distribution
x0: float = rng.normal(
loc=model_params["mean"],
scale=model_params["std"]
)

# ----------------------------
# MODEL 1: IID Gaussian
# ----------------------------
if model_option == "iid":
return sample_iid_gaussian(
N=N,
mu=model_params["mean"],
sigma=model_params["std"],
rng=rng,
)

# ----------------------------
# MODEL 2: Random Walk
# ----------------------------
if model_option == "random_walk":
required = ["drift", "epsilon_std"]
for key in required:
if key not in model_params:
raise KeyError(f"src_stats must contain '{key}' for random_walk")

return random_walk_with_drift(
N=N,
x0=x0,
drift=model_params["drift"],
sigma_eps=model_params["epsilon_std"],
rng=rng,
)

# ----------------------------
# MODEL 3: AR(1)
# ----------------------------
if model_option == "ar1":
required = ["mu", "phi", "epsilon_std"]
for key in required:
if key not in model_params:
raise KeyError(f"src_stats must contain '{key}' for ar1")

return ar1_process(
N=N,
x0=x0,
mu=model_params["mu"],
phi=model_params["phi"],
sigma_eps=model_params["epsilon_std"],
rng=rng,
)

# ----------------------------
raise ValueError(f"Unknown model_option: {model_option!r}")
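A quick illustrative use of the dispatcher above; the parameter values are arbitrary examples rather than statistics taken from any real src_stats:

iid = generate_time_series(5, "iid", {"mean": 100.0, "std": 10.0}, random_state=0)

walk = generate_time_series(
    5, "random_walk",
    {"mean": 100.0, "std": 10.0, "drift": 0.5, "epsilon_std": 2.0},
    random_state=0,
)

ar1 = generate_time_series(
    5, "ar1",
    {"mean": 100.0, "std": 10.0, "mu": 100.0, "phi": 0.8, "epsilon_std": 2.0},
    random_state=0,
)
# Each call returns an np.ndarray of length 5; "mean" and "std" are always
# required because they are used to draw the initial value x0.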


def sample_iid_gaussian(
N: int,
mu: float,
sigma: float,
rng: np.random.Generator
) -> np.ndarray:
""""
Generate an IID Gaussian time series.

Parameters
----------
N : int
Length of the time series.
mu : float
Mean of the Gaussian.
sigma : float
Standard deviation of the Gaussian.
rng : np.random.Generator
Random number generator.

Returns
-------
np.ndarray
Generated IID Gaussian time series of length N.
"""
return rng.normal(loc=mu, scale=sigma, size=N)



def random_walk_with_drift(
N: int,
x0: float,
drift: float,
sigma_eps: float,
rng: np.random.Generator
) -> np.ndarray:
"""
Generate a random walk time series with drift.

Parameters
----------
N : int
Length of the time series.
x0 : float
Initial value of the time series.
drift : float
Drift term added at each time step.
sigma_eps : float
Standard deviation of the white noise.
rng : np.random.Generator
Random number generator.

Returns
-------
np.ndarray
Generated random walk time series of length N.
"""
x = np.empty(N)
x[0] = x0
for t in range(1, N):
x[t] = x[t-1] + drift + rng.normal(0.0, sigma_eps)
return x
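Since x[t] = x0 + drift * t + the cumulative sum of the noise terms, an equivalent vectorised construction is possible; the sketch below is only an optional alternative, not a change this PR makes:

def random_walk_with_drift_vectorised(
    N: int, x0: float, drift: float, sigma_eps: float, rng: np.random.Generator
) -> np.ndarray:
    """Closed-form random walk: x[t] = x0 + drift * t + cumulative noise."""
    increments = drift + rng.normal(0.0, sigma_eps, size=N - 1)  # steps for t = 1..N-1
    return np.concatenate(([x0], x0 + np.cumsum(increments)))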


def ar1_process(
N: int,
x0: float,
mu: float,
phi: float,
sigma_eps: float,
rng: np.random.Generator
) -> np.ndarray:
"""
Generate an AR(1) time series.
An AR(1) process is defined by the equation:
x[t] = mu + phi * (x[t-1] - mu) + eps[t]
where eps[t] ~ N(0, sigma_eps^2)

Parameters
----------
N : int
Length of the time series.
x0 : float
Initial value of the time series.
mu : float
Mean of the AR(1) process.
phi : float
Autoregressive coefficient.
sigma_eps : float
Standard deviation of the white noise.
rng : np.random.Generator
Random number generator.

Returns
-------
np.ndarray
Generated AR(1) time series of length N.
"""
x = np.empty(N)
x[0] = x0
for t in range(1, N):
eps = rng.normal(0.0, sigma_eps)
x[t] = mu + phi * (x[t-1] - mu) + eps
return x
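For |phi| < 1 the recursion above is stationary with mean mu and variance sigma_eps**2 / (1 - phi**2); a quick sanity check with arbitrary parameters:

rng = np.random.default_rng(0)
series = ar1_process(N=100_000, x0=100.0, mu=100.0, phi=0.8, sigma_eps=2.0, rng=rng)
print(series.mean())  # close to mu = 100.0
print(series.var())   # close to 2.0**2 / (1 - 0.8**2) ≈ 11.1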