Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ A curated evaluation set annotated by board-certified radiologists for validatin

## Acknowledgments

Built on the work of the radiology AI community: [CheXbert](https://github.com/stanfordmlgroup/CheXbert), [RadGraph](https://github.com/jbdel/RadGraph), [BERTScore](https://github.com/Tiiiger/bert_score), [RaTEScore](https://github.com/MAGIC-AI4Med/RaTEScore), [SRR-BERT](https://github.com/StanfordAIMI/SRR-BERT), [GREEN](https://github.com/Stanford-AIMI/GREEN), and datasets like [MIMIC-CXR](https://physionet.org/content/mimic-cxr/2.0.0/).
Built on the work of the radiology AI community: [CheXbert](https://github.com/stanfordmlgroup/CheXbert), [RadGraph](https://github.com/jbdel/RadGraph), [BERTScore](https://github.com/Tiiiger/bert_score), [RaTEScore](https://github.com/MAGIC-AI4Med/RaTEScore), [SRR-BERT](https://github.com/StanfordAIMI/SRR-BERT), [GREEN](https://github.com/Stanford-AIMI/GREEN), [CRIMSON](https://github.com/rajpurkarlab/CRIMSON), and datasets like [MIMIC-CXR](https://physionet.org/content/mimic-cxr/2.0.0/).

---
<div align="center">
Expand Down
3 changes: 3 additions & 0 deletions RadEval/RadEval.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self,
crimson_model=None,
crimson_batch_size=1,
crimson_max_concurrent=50,
cache_dir=None,
hoppr_crimson_ct_api="openai",
hoppr_crimson_ct_model=None,
do_per_sample=False,
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(self,
self.crimson_model = crimson_model
self.crimson_batch_size = crimson_batch_size
self.crimson_max_concurrent = crimson_max_concurrent
self.cache_dir = cache_dir
self.hoppr_crimson_ct_api = hoppr_crimson_ct_api
self.hoppr_crimson_ct_model = hoppr_crimson_ct_model
self.do_radeval_bertscore = do_radeval_bertscore
Expand Down Expand Up @@ -184,6 +186,7 @@ def __init__(self,
gemini_api_key=self.gemini_api_key,
batch_size=self.crimson_batch_size,
max_concurrent=self.crimson_max_concurrent,
cache_dir=self.cache_dir,
)
except (ImportError, EnvironmentError, OSError) as e:
warnings.warn(
Expand Down
52 changes: 38 additions & 14 deletions RadEval/metrics/crimson/crimson.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from .._llm_base import LLMMetricBase
from .prompt_parts import build_prompt as _build_evaluation_prompt_fn
from .utils import parse_json_response as _parse_json_response_robust

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -113,8 +114,9 @@ class CRIMSONScore(LLMMetricBase):

SUPPORTED_PROVIDERS: ClassVar[set[str]] = {"openai", "hf"}

DEFAULT_HF_MODEL = "CRIMSONScore/medgemma-4b-it-crimson"
DEFAULT_HF_MODEL = "rajpurkarlab/medgemma-4b-it-crimson"
DEFAULT_OPENAI_MODEL = "gpt-5.2"
DEFAULT_MAX_NEW_TOKENS = 8192

def __init__(
self,
Expand All @@ -125,6 +127,7 @@ def __init__(
device=None,
batch_size=1,
max_concurrent=50,
cache_dir=None,
):
resolved_model = model_name or (
self.DEFAULT_OPENAI_MODEL if provider == "openai"
Expand All @@ -140,6 +143,7 @@ def __init__(
)

self.batch_size = batch_size
self.cache_dir = cache_dir

if provider in ("huggingface", "hf"):
self._init_hf_pipeline()
Expand All @@ -159,13 +163,21 @@ def _init_hf_pipeline(self):
logger.info("Using Flash Attention 2.")
except ImportError:
pass
if self.cache_dir:
model_kwargs["cache_dir"] = self.cache_dir

self.pipe = transformers.pipeline(
"text-generation",
model=self.model_name,
model_kwargs=model_kwargs,
device_map="auto",
)
# Left-pad for efficient batch generation with decoder-only models
if self.pipe.tokenizer.padding_side != "left":
self.pipe.tokenizer.padding_side = "left"
# If the model has a generation_config.json, use it instead of
# injecting our own generation kwargs at inference time.
self._has_generation_config = not self.pipe.model.generation_config._from_model_config
logger.info("Model loaded.")

# ------------------------------------------------------------------
Expand Down Expand Up @@ -194,8 +206,8 @@ def _build_request(self, ref: str, hyp: str, **kwargs) -> dict[str, Any]:
def _parse_response(self, raw: str) -> dict:
cleaned = _extract_json_str(raw)
try:
evaluation = json.loads(cleaned)
except json.JSONDecodeError:
evaluation = _parse_json_response_robust(cleaned)
except (ValueError, json.JSONDecodeError):
repaired = _repair_truncated_json(cleaned)
if repaired is not None:
logger.warning("Repaired truncated CRIMSON JSON.")
Expand Down Expand Up @@ -250,14 +262,23 @@ def _chat_completion(self, request: dict[str, Any]) -> str:
def _hf_generate(self, prompt: str) -> str:
messages = [
{"role": "system", "content": _SYSTEM_MSG},
{"role": "user", "content": prompt + "\nPlease respond with valid JSON only."},
{"role": "user", "content": prompt},
]
outputs = self.pipe(
messages,
max_new_tokens=8192,
do_sample=False,
repetition_penalty=1.1,
)

if self._has_generation_config:
outputs = self.pipe(
messages,
generation_config=self.pipe.model.generation_config,
repetition_penalty=1.1,
)
else:
outputs = self.pipe(
messages,
max_new_tokens=self.DEFAULT_MAX_NEW_TOKENS,
max_length=None,
do_sample=False,
repetition_penalty=1.1,
)
response = outputs[0]["generated_text"][-1]["content"]
if not response:
logger.warning("Empty response from HF pipeline. Raw: %s", outputs)
Expand Down Expand Up @@ -385,6 +406,9 @@ def _nan_fallback():

def _build_evaluation_prompt(self, reference_findings, predicted_findings,
patient_context=None, include_guidelines=True):
# For the MedGemma CRIMSON model, exclude guidelines — it was trained without them.
if self.provider in ("huggingface", "hf") and self.model_name == self.DEFAULT_HF_MODEL:
include_guidelines = False
return _build_evaluation_prompt_fn(
reference_findings,
predicted_findings,
Expand Down Expand Up @@ -414,15 +438,15 @@ def _calculate_crimson(self, evaluation):

def calculate_weighted_count(error_list, weights=significance_weights,
key="clinical_significance"):
return sum(weights.get(error.get(key, "benign_expected"), 0.0)
return sum(weights.get(error.get(key, ""), 0.25)
for error in error_list)

ref_weight_by_id = {
ref["id"]: significance_weights.get(ref.get("clinical_significance", "benign_expected"), 0.0)
ref["id"]: significance_weights.get(ref.get("clinical_significance", ""), 0.25)
for ref in reference_findings_list
}
pred_weight_by_id = {
pred["id"]: significance_weights.get(pred.get("clinical_significance", "benign_expected"), 0.0)
pred["id"]: significance_weights.get(pred.get("clinical_significance", ""), 0.25)
for pred in predicted_findings_list
}

Expand Down Expand Up @@ -467,7 +491,7 @@ def calculate_weighted_count(error_list, weights=significance_weights,
correct += base_weight
else:
sum_error_weights = sum(
attribute_severity_weights.get(err.get("severity", "negligible"), 0.0)
attribute_severity_weights.get(err.get("severity", ""), 0.25)
for err in finding_attr_errors)
denom = base_weight + sum_error_weights
credit_factor = base_weight / denom if denom > 0 else 0.0
Expand Down
180 changes: 180 additions & 0 deletions RadEval/metrics/crimson/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
"""JSON parsing utilities for CRIMSON scoring.

Ported from the standalone CRIMSON package to provide robust JSON parsing
for LLM responses that may contain malformed JSON.
"""

import json
import re

# ---------------------------------------------------------------------------
# Regex for fixing invalid JSON backslash escapes
# ---------------------------------------------------------------------------

# Fix invalid JSON backslash escapes (e.g. \_ or \L) that the model may
# produce when echoing back input text. Valid JSON escapes
# (" \\ \/ \b \f \n \r \t \uXXXX) are left untouched.
# The negative lookahead lists every character that may legally follow a
# backslash in JSON; any other backslash gets doubled later in the pipeline.
# NOTE(review): "/" appears twice in the class (`\/` then `/`) — redundant but
# harmless.
_INVALID_ESCAPE_RE = re.compile(r'\\(?!["\\\//bfnrtu])')

# ---------------------------------------------------------------------------
# Regex for removing orphan keys in JSON objects
# ---------------------------------------------------------------------------

# Matches a bare string value inside a JSON object — a quoted string preceded
# by a comma and followed by a comma or closing brace, but NOT followed by a
# colon. Example: ,"pred_old", or ,"pred_old"}
# The leading comma is consumed so removal doesn't leave double commas.
_ORPHAN_KEY_AFTER_COMMA_RE = re.compile(r',\s*"[^"]*"\s*(?=\s*[,}])(?!\s*:)')
# Orphan as the first item in an object: {"orphan","key":"val"}
# The trailing comma is consumed; the lookahead keeps the next key's quote.
_ORPHAN_KEY_AT_START_RE = re.compile(r'(?<=\{)\s*"[^"]*"\s*,\s*(?=")')

# Malformed array-object hybrid: ["":["R1"] -> ["R1"]
_MALFORMED_ARRAY_OBJ_RE = re.compile(r'\[""\s*:\s*\[')

# Missing opening quote on a JSON key: ,attribute_errors":[ -> ,"attribute_errors":[
# Group 1 is the preceding structural context (end of previous value, or an
# opening brace); group 2 is the bare identifier plus its closing quote.
_MISSING_OPEN_QUOTE_RE = re.compile(
    r'([\]}"0-9]\s*,\s*|{\s*)([a-zA-Z_][a-zA-Z0-9_]*")'
)


# ---------------------------------------------------------------------------
# JSON response parsing
# ---------------------------------------------------------------------------

def _dedupe_keys(pairs):
"""JSON object_pairs_hook that keeps the first occurrence of duplicate keys."""
result = {}
for key, value in pairs:
if key not in result:
result[key] = value
return result


def _loads(text):
    """Parse *text* as JSON; duplicate object keys keep their first value."""
    parsed = json.loads(text, object_pairs_hook=_dedupe_keys)
    return parsed


def _is_structural_quote(text, i):
"""Check if the quote at position *i* is likely a JSON structural delimiter.

Returns True for quotes that delimit JSON keys/values (after ``:``, before
``:`` etc.) and False for quotes that are part of prose inside a string
value (e.g. the model quoting a phrase in an explanation).
"""
before = text[i - 1] if i > 0 else ''
after = text[i + 1] if i + 1 < len(text) else ''
# Quotes right after : [ { are always structural (opening a value/key)
if before in (':', '[', '{'):
return True
# Quotes right before : ] } are always structural (closing a key/value)
if after in (':', ']', '}'):
return True
# For comma-adjacent quotes, check deeper context.
if before == ',':
pre_comma = text[i - 2] if i >= 2 else ''
return pre_comma in ('"', ']', '}') or pre_comma.isdigit()
if after == ',':
post_comma = text[i + 2] if i + 2 < len(text) else ''
return post_comma in ('"', '[', '{') or post_comma.isdigit()
return False


def _fix_unescaped_quotes(text, max_attempts=50):
    """Iteratively escape stray double-quotes that break JSON parsing.

    When the model writes a literal ``"`` inside a JSON string value (e.g.
    quoting a phrase in an explanation), the parser treats it as a string
    terminator. Each round parses the text and, on failure, scans a small
    window just before the reported error position for a quote that does not
    look structural, escapes it, and retries.

    Raises json.JSONDecodeError when no fixable quote can be located or the
    attempt budget is exhausted.
    """
    for _ in range(max_attempts):
        try:
            return _loads(text)
        except json.JSONDecodeError as err:
            window_start = max(0, err.pos - 5)
            fix_at = None
            # Scan backwards from the error position over a 5-char window.
            for idx in range(err.pos, window_start - 1, -1):
                if idx >= len(text) or text[idx] != '"':
                    continue
                if idx > 0 and text[idx - 1] == '\\':
                    continue  # already escaped
                if _is_structural_quote(text, idx):
                    continue  # looks like a real key/value delimiter
                fix_at = idx
                break
            if fix_at is None:
                raise
            text = text[:fix_at] + '\\"' + text[fix_at + 1:]
    raise json.JSONDecodeError("Max quote-fix attempts reached", text, 0)


def _fix_orphan_keys(text):
    """Strip bare string values (orphan keys) out of JSON objects.

    The model sometimes hallucinates partial key fragments (e.g.
    ``"pred_old"``) that appear as bare values inside objects, producing
    invalid JSON like::

        {"ref_id":"R3","pred_old","pred_id":"P2"}

    Removing them makes the JSON parsable again.
    """
    without_mid_orphans = _ORPHAN_KEY_AFTER_COMMA_RE.sub('', text)
    return _ORPHAN_KEY_AT_START_RE.sub('', without_mid_orphans)


def parse_json_response(response, batch_idx=None):
    """Parse a model response as JSON, applying progressively aggressive fixes.

    The pipeline (each later step only runs when the earlier ones fail):

    1. Raw parse (after a few cheap text normalizations)
    2. Remove orphan keys (bare strings in object context)
    3. Fix invalid backslash escapes (``\\L``, ``\\_``, etc.)
    4. Fix unescaped double-quotes inside string values

    Args:
        response: Raw model output string.
        batch_idx: Optional index for error context in batch evaluation.

    Returns:
        Parsed JSON object (dict).

    Raises:
        ValueError: If no variant of the response can be parsed as JSON.
    """
    # 0. Pre-parse text fixes.
    # Curly/smart quotes always appear inside JSON string values, so they are
    # escaped rather than dropped.
    for smart, replacement in (
        ('\u201c', '\\"'),
        ('\u201d', '\\"'),
        ('\u2018', "\\'"),
        ('\u2019', "\\'"),
    ):
        response = response.replace(smart, replacement)
    # Malformed array-object hybrids: ["":["R1"] -> ["R1"]
    response = _MALFORMED_ARRAY_OBJ_RE.sub('[', response)
    # Missing opening quote on keys: ,attribute_errors" -> ,"attribute_errors"
    response = _MISSING_OPEN_QUOTE_RE.sub(r'\1"\2', response)

    escaped = _INVALID_ESCAPE_RE.sub(r'\\\\', response)
    # Steps 1-3: plain parses of progressively repaired variants, in order:
    # raw, orphan-keys removed, invalid escapes doubled.
    for candidate in (response, _fix_orphan_keys(response), escaped):
        try:
            return _loads(candidate)
        except json.JSONDecodeError:
            continue
    # Step 4: quote repair, tried on both the raw and escape-fixed variants.
    for candidate in (response, escaped):
        try:
            return _fix_unescaped_quotes(candidate)
        except (json.JSONDecodeError, ValueError):
            continue
    ctx = f" for batch item {batch_idx}" if batch_idx is not None else ""
    raise ValueError(
        f"Failed to parse model response as JSON{ctx}\n"
        f"Response ({len(response)} chars):\n{response}"
    )