Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 54 additions & 13 deletions app/server/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,22 +191,31 @@ async def save_blob(
return await cls.save_key(store, key, blob, ttl=ttl)

async def init(
self, jurisdiction_id: str, case_id: str, ttl: int = _DEFAULT_TTL
self,
jurisdiction_id: str | bytes,
case_id: str | bytes,
ttl: int = _DEFAULT_TTL,
) -> None:
"""Initialize the case store.

Args:
jurisdiction_id (str): The jurisdiction ID.
case_id (str): The case ID.
jurisdiction_id (str | bytes): The jurisdiction ID.
case_id (str | bytes): The case ID.
ttl (int, optional): The time-to-live in seconds.

Returns:
None
"""
if self.inited:
return
self.jurisdiction_id = jurisdiction_id
self.case_id = case_id
self.jurisdiction_id = (
jurisdiction_id.decode("utf-8")
if isinstance(jurisdiction_id, bytes)
else jurisdiction_id
)
self.case_id = (
case_id.decode("utf-8") if isinstance(case_id, bytes) else case_id
)
self.expires_at = await self._set_expiration(ttl)
logger.debug("CaseStore initialized, will expire at %d", self.expires_at)

Expand Down Expand Up @@ -245,6 +254,8 @@ async def save_masked_name(self, subject_id: str, mask: str) -> None:
Returns:
None
"""
if mask is None:
return
await self.save_masked_names({subject_id: mask})

@ensure_init
Expand All @@ -260,6 +271,10 @@ async def save_masked_names(self, masks: SimpleMapping | IdToMaskMap) -> None:
if not masks:
return
simple_masks = masks._map.copy() if isinstance(masks, IdToMaskMap) else masks
# Guard against None values, which cause errors.
simple_masks = {k: v for k, v in simple_masks.items() if v is not None}
if not simple_masks:
return
mapping_key = self.key("mask")
await self.store.hsetmapping(mapping_key, simple_masks)
await self.store.expire_at(mapping_key, self.expires_at)
Expand All @@ -277,33 +292,46 @@ async def save_placeholders(self, masks: SimpleMapping | NameToMaskMap) -> None:
if not masks:
return
simple_masks = masks._map.copy() if isinstance(masks, NameToMaskMap) else masks

# Guard against None values, which cause errors.
simple_masks = {k: v for k, v in simple_masks.items() if v is not None}
if not simple_masks:
return

mapping_key = self.key("placeholders")
await self.store.hsetmapping(mapping_key, simple_masks)
await self.store.expire_at(mapping_key, self.expires_at)

@ensure_init
async def save_result_doc(self, doc_id: str, doc: OutputDocument) -> None:
async def save_result_doc(self, doc_id: str | bytes, doc: OutputDocument) -> None:
"""Save a result ID for a case.

Args:
doc_id (str): The document ID.
doc_id (str | bytes): The document ID as a string.
doc (OutputDocument): The document.

Returns:
None
"""
if isinstance(doc_id, bytes):
doc_id = doc_id.decode("utf-8")
k = self.key("result:" + doc_id)
serialized_doc = doc.model_dump_json()
await self.store.set(k, serialized_doc)
await self.store.expire_at(k, self.expires_at)

@ensure_init
async def get_result_doc(self, doc_id: str) -> OutputDocument | None:
async def get_result_doc(self, doc_id: str | bytes) -> OutputDocument | None:
"""Get the result ID for a case.

Args:
doc_id (str | bytes): The document ID as a string.

Returns:
OutputDocument: The document.
"""
if isinstance(doc_id, bytes):
doc_id = doc_id.decode("utf-8")
k = self.key("result:" + doc_id)
serialized_doc = await self.store.get(k)
if not serialized_doc:
Expand All @@ -327,6 +355,10 @@ async def save_roles(
# even though it is fully compatible.
# https://stackoverflow.com/a/72841649
srm = cast(SimpleMapping, subject_role_mapping)
# Guard against None values, which cause errors.
srm = {k: v for k, v in srm.items() if v is not None}
if not srm:
return
k = self.key("role")
await self.store.hsetmapping(k, srm)
await self.store.expire_at(k, self.expires_at)
Expand Down Expand Up @@ -444,16 +476,18 @@ async def get_doc_tasks(self) -> dict[str, list[str]]:
return {k.decode(): v.decode().split(",") for k, v in result.items()}

@ensure_init
async def save_doc_task(self, doc_id: str, task: AsyncResult) -> None:
async def save_doc_task(self, doc_id: str | bytes, task: AsyncResult) -> None:
"""Save a document task ID.

Args:
doc_id (str): The document ID.
doc_id (str | bytes): The document ID as a string.
task (AsyncResult): The task result promise.

Returns:
None
"""
if isinstance(doc_id, bytes):
doc_id = doc_id.decode("utf-8")
k = self.key("task")
task_ids = list[str]()
# Flatten the list of task IDs from the result chain
Expand Down Expand Up @@ -496,18 +530,23 @@ async def pop_object(self) -> RedactionTarget | None:

@ensure_init
async def save_real_name(
self, subject_id: str, alias: HumanName, primary: bool = False
self, subject_id: str | bytes, alias: HumanName, primary: bool = False
) -> None:
"""Save a real name for a subject.

Args:
subject_id (str): The subject ID.
subject_id (str | bytes): The subject ID as a string.
alias (HumanName): The alias.
primary (bool, optional): Whether the alias is primary. Defaults to False.

Returns:
None
"""
# NOTE(jnu): we really prefer to have the interface deal exclusively with
# strings, not bytes, but since bytestrings are strewn about the codebase
# we need to support both, as a defensive measure.
if isinstance(subject_id, bytes):
subject_id = subject_id.decode("utf-8")
subject_key = f"aliases:{subject_id}"

if primary:
Expand All @@ -519,7 +558,7 @@ async def save_real_name(
await self.store.expire_at(self.key(subject_key), self.expires_at)

@ensure_init
def key(self, category: str) -> str:
def key(self, category: str | bytes) -> str:
"""Generate a key for a redis value.

Args:
Expand All @@ -528,4 +567,6 @@ def key(self, category: str) -> str:
Returns:
str: The key.
"""
if isinstance(category, bytes):
category = category.decode("utf-8")
return f"{self.jurisdiction_id}:{self.case_id}:{category}"
77 changes: 77 additions & 0 deletions tests/unit/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,80 @@ async def test_get_name_mask_map(fake_redis_store: FakeRedis, config: Config, sp
mask_info = await cs.get_mask_info()
assert mask_info.get_name_mask_map() == spec["expected_name_mask_map"]
assert mask_info.get_id_name_map() == spec["expected_id_name_map"]


async def test_save_info(fake_redis_store: FakeRedis, config: Config):
async with config.queue.store.driver() as store:
async with store.tx() as tx:
cs = CaseStore(tx)
await cs.init("jur1", "case1")
await cs.save_roles({"sub1": "accused"})
await cs.save_masked_names({"sub1": "Accused 1"})
await cs.save_real_name(
"sub1", HumanName(firstName="jack", lastName="doe"), primary=True
)

async with store.tx() as tx:
cs = CaseStore(tx)
await cs.init("jur1", "case1")
mask_info = await cs.get_mask_info()
assert mask_info.get_name_mask_map() == NameToMaskMap(
{"jack doe": "Accused 1"}
)
assert mask_info.get_id_name_map() == IdToNameMap({"sub1": "jack doe"})


async def test_save_null_inferred_data(fake_redis_store: FakeRedis, config: Config):
async with config.queue.store.driver() as store:
async with store.tx() as tx:
cs = CaseStore(tx)
await cs.init("jur1", "case1")
await cs.save_roles({"sub1": "accused", "sub2": "victim"})
await cs.save_real_name(
"sub1", HumanName(firstName="jack", lastName="doe"), primary=True
)
await cs.save_real_name(
"sub2", HumanName(firstName="jane", lastName="doe"), primary=True
)
await cs.save_masked_names({"sub1": "redacted 1", "sub2": None})
await cs.save_placeholders({"subway": "location 1", "target": None})

async with store.tx() as tx:
cs = CaseStore(tx)
await cs.init("jur1", "case1")
mask_info = await cs.get_mask_info()
assert mask_info.get_name_mask_map() == NameToMaskMap(
{
"jack doe": "redacted 1", # overwritten by name text inference
"jane doe": "Victim 1", # inferred from role enumeration
"subway": "location 1", # inferred from placeholder text inference
}
)
assert mask_info.get_id_name_map() == IdToNameMap(
{
"sub1": "jack doe",
"sub2": "jane doe",
}
)


async def test_save_placeholders(fake_redis_store: FakeRedis, config: Config):
async with config.queue.store.driver() as store:
async with store.tx() as tx:
cs = CaseStore(tx)
await cs.init("jur1", "case1")
await cs.save_placeholders({"jack doe": "Accused 99"})
await cs.save_roles({"sub1": "accused"})
await cs.save_masked_names({"sub1": "Accused 1"})
await cs.save_real_name(
"sub1", HumanName(firstName="jack", lastName="doe"), primary=True
)

async with store.tx() as tx:
cs = CaseStore(tx)
await cs.init("jur1", "case1")
mask_info = await cs.get_mask_info()
assert mask_info.get_name_mask_map() == NameToMaskMap(
{"jack doe": "Accused 99"}
)
assert mask_info.get_id_name_map() == IdToNameMap({"sub1": "jack doe"})
Loading