diff --git a/README.md b/README.md
index c7b71d3..cfd626d 100755
--- a/README.md
+++ b/README.md
@@ -234,7 +234,7 @@ A curated evaluation set annotated by board-certified radiologists for validatin
## Acknowledgments
-Built on the work of the radiology AI community: [CheXbert](https://github.com/stanfordmlgroup/CheXbert), [RadGraph](https://github.com/jbdel/RadGraph), [BERTScore](https://github.com/Tiiiger/bert_score), [RaTEScore](https://github.com/MAGIC-AI4Med/RaTEScore), [SRR-BERT](https://github.com/StanfordAIMI/SRR-BERT), [GREEN](https://github.com/Stanford-AIMI/GREEN), and datasets like [MIMIC-CXR](https://physionet.org/content/mimic-cxr/2.0.0/).
+Built on the work of the radiology AI community: [CheXbert](https://github.com/stanfordmlgroup/CheXbert), [RadGraph](https://github.com/jbdel/RadGraph), [BERTScore](https://github.com/Tiiiger/bert_score), [RaTEScore](https://github.com/MAGIC-AI4Med/RaTEScore), [SRR-BERT](https://github.com/StanfordAIMI/SRR-BERT), [GREEN](https://github.com/Stanford-AIMI/GREEN), [CRIMSON](https://github.com/rajpurkarlab/CRIMSON), and datasets like [MIMIC-CXR](https://physionet.org/content/mimic-cxr/2.0.0/).
---
diff --git a/RadEval/RadEval.py b/RadEval/RadEval.py
index 2d2eb47..1bab4dc 100755
--- a/RadEval/RadEval.py
+++ b/RadEval/RadEval.py
@@ -38,6 +38,7 @@ def __init__(self,
crimson_model=None,
crimson_batch_size=1,
crimson_max_concurrent=50,
+ cache_dir=None,
hoppr_crimson_ct_api="openai",
hoppr_crimson_ct_model=None,
do_per_sample=False,
@@ -76,6 +77,7 @@ def __init__(self,
self.crimson_model = crimson_model
self.crimson_batch_size = crimson_batch_size
self.crimson_max_concurrent = crimson_max_concurrent
+ self.cache_dir = cache_dir
self.hoppr_crimson_ct_api = hoppr_crimson_ct_api
self.hoppr_crimson_ct_model = hoppr_crimson_ct_model
self.do_radeval_bertscore = do_radeval_bertscore
@@ -184,6 +186,7 @@ def __init__(self,
gemini_api_key=self.gemini_api_key,
batch_size=self.crimson_batch_size,
max_concurrent=self.crimson_max_concurrent,
+ cache_dir=self.cache_dir,
)
except (ImportError, EnvironmentError, OSError) as e:
warnings.warn(
diff --git a/RadEval/metrics/crimson/crimson.py b/RadEval/metrics/crimson/crimson.py
index ea76dc7..aa735ce 100644
--- a/RadEval/metrics/crimson/crimson.py
+++ b/RadEval/metrics/crimson/crimson.py
@@ -21,6 +21,7 @@
from .._llm_base import LLMMetricBase
from .prompt_parts import build_prompt as _build_evaluation_prompt_fn
+from .utils import parse_json_response as _parse_json_response_robust
logger = logging.getLogger(__name__)
@@ -113,8 +114,9 @@ class CRIMSONScore(LLMMetricBase):
SUPPORTED_PROVIDERS: ClassVar[set[str]] = {"openai", "hf"}
- DEFAULT_HF_MODEL = "CRIMSONScore/medgemma-4b-it-crimson"
+ DEFAULT_HF_MODEL = "rajpurkarlab/medgemma-4b-it-crimson"
DEFAULT_OPENAI_MODEL = "gpt-5.2"
+ DEFAULT_MAX_NEW_TOKENS = 8192
def __init__(
self,
@@ -125,6 +127,7 @@ def __init__(
device=None,
batch_size=1,
max_concurrent=50,
+ cache_dir=None,
):
resolved_model = model_name or (
self.DEFAULT_OPENAI_MODEL if provider == "openai"
@@ -140,6 +143,7 @@ def __init__(
)
self.batch_size = batch_size
+ self.cache_dir = cache_dir
if provider in ("huggingface", "hf"):
self._init_hf_pipeline()
@@ -159,6 +163,8 @@ def _init_hf_pipeline(self):
logger.info("Using Flash Attention 2.")
except ImportError:
pass
+ if self.cache_dir:
+ model_kwargs["cache_dir"] = self.cache_dir
self.pipe = transformers.pipeline(
"text-generation",
@@ -166,6 +172,12 @@ def _init_hf_pipeline(self):
model_kwargs=model_kwargs,
device_map="auto",
)
+ # Left-pad for efficient batch generation with decoder-only models
+ if self.pipe.tokenizer.padding_side != "left":
+ self.pipe.tokenizer.padding_side = "left"
+ # If the model has a generation_config.json, use it instead of
+ # injecting our own generation kwargs at inference time.
+ self._has_generation_config = not self.pipe.model.generation_config._from_model_config
logger.info("Model loaded.")
# ------------------------------------------------------------------
@@ -194,8 +206,8 @@ def _build_request(self, ref: str, hyp: str, **kwargs) -> dict[str, Any]:
def _parse_response(self, raw: str) -> dict:
cleaned = _extract_json_str(raw)
try:
- evaluation = json.loads(cleaned)
- except json.JSONDecodeError:
+ evaluation = _parse_json_response_robust(cleaned)
+ except (ValueError, json.JSONDecodeError):
repaired = _repair_truncated_json(cleaned)
if repaired is not None:
logger.warning("Repaired truncated CRIMSON JSON.")
@@ -250,14 +262,23 @@ def _chat_completion(self, request: dict[str, Any]) -> str:
def _hf_generate(self, prompt: str) -> str:
messages = [
{"role": "system", "content": _SYSTEM_MSG},
- {"role": "user", "content": prompt + "\nPlease respond with valid JSON only."},
+ {"role": "user", "content": prompt},
]
- outputs = self.pipe(
- messages,
- max_new_tokens=8192,
- do_sample=False,
- repetition_penalty=1.1,
- )
+
+ if self._has_generation_config:
+ outputs = self.pipe(
+ messages,
+ generation_config=self.pipe.model.generation_config,
+ repetition_penalty=1.1,
+ )
+ else:
+ outputs = self.pipe(
+ messages,
+ max_new_tokens=self.DEFAULT_MAX_NEW_TOKENS,
+ max_length=None,
+ do_sample=False,
+ repetition_penalty=1.1,
+ )
response = outputs[0]["generated_text"][-1]["content"]
if not response:
logger.warning("Empty response from HF pipeline. Raw: %s", outputs)
@@ -385,6 +406,9 @@ def _nan_fallback():
def _build_evaluation_prompt(self, reference_findings, predicted_findings,
patient_context=None, include_guidelines=True):
+        # For the MedGemma CRIMSON model, exclude guidelines; it was trained without them.
+ if self.provider in ("huggingface", "hf") and self.model_name == self.DEFAULT_HF_MODEL:
+ include_guidelines = False
return _build_evaluation_prompt_fn(
reference_findings,
predicted_findings,
@@ -414,15 +438,15 @@ def _calculate_crimson(self, evaluation):
def calculate_weighted_count(error_list, weights=significance_weights,
key="clinical_significance"):
- return sum(weights.get(error.get(key, "benign_expected"), 0.0)
+ return sum(weights.get(error.get(key, ""), 0.25)
for error in error_list)
ref_weight_by_id = {
- ref["id"]: significance_weights.get(ref.get("clinical_significance", "benign_expected"), 0.0)
+ ref["id"]: significance_weights.get(ref.get("clinical_significance", ""), 0.25)
for ref in reference_findings_list
}
pred_weight_by_id = {
- pred["id"]: significance_weights.get(pred.get("clinical_significance", "benign_expected"), 0.0)
+ pred["id"]: significance_weights.get(pred.get("clinical_significance", ""), 0.25)
for pred in predicted_findings_list
}
@@ -467,7 +491,7 @@ def calculate_weighted_count(error_list, weights=significance_weights,
correct += base_weight
else:
sum_error_weights = sum(
- attribute_severity_weights.get(err.get("severity", "negligible"), 0.0)
+ attribute_severity_weights.get(err.get("severity", ""), 0.25)
for err in finding_attr_errors)
denom = base_weight + sum_error_weights
credit_factor = base_weight / denom if denom > 0 else 0.0
diff --git a/RadEval/metrics/crimson/utils.py b/RadEval/metrics/crimson/utils.py
new file mode 100644
index 0000000..2460f50
--- /dev/null
+++ b/RadEval/metrics/crimson/utils.py
@@ -0,0 +1,180 @@
+"""JSON parsing utilities for CRIMSON scoring.
+
+Ported from the standalone CRIMSON package to provide robust JSON parsing
+for LLM responses that may contain malformed JSON.
+"""
+
+import json
+import re
+
+# ---------------------------------------------------------------------------
+# Regex for fixing invalid JSON backslash escapes
+# ---------------------------------------------------------------------------
+
+# Fix invalid JSON backslash escapes (e.g. \_ or \L) that the model may
+# produce when echoing back input text. Valid JSON escapes
+# (" \\ \/ \b \f \n \r \t \uXXXX) are left untouched.
+_INVALID_ESCAPE_RE = re.compile(r'\\(?!["\\/bfnrtu])')
+
+# ---------------------------------------------------------------------------
+# Regex for removing orphan keys in JSON objects
+# ---------------------------------------------------------------------------
+
+# Matches a bare string value inside a JSON object — a quoted string preceded
+# by a comma and followed by a comma or closing brace, but NOT followed by a
+# colon. Example: ,"pred_old", or ,"pred_old"}
+# The leading comma is consumed so removal doesn't leave double commas.
+_ORPHAN_KEY_AFTER_COMMA_RE = re.compile(r',\s*"[^"]*"\s*(?=\s*[,}])(?!\s*:)')
+# Orphan as the first item in an object: {"orphan","key":"val"}
+_ORPHAN_KEY_AT_START_RE = re.compile(r'(?<=\{)\s*"[^"]*"\s*,\s*(?=")')
+
+# Malformed array-object hybrid: ["":["R1"] -> ["R1"]
+_MALFORMED_ARRAY_OBJ_RE = re.compile(r'\[""\s*:\s*\[')
+
+# Missing opening quote on a JSON key: ,attribute_errors":[ -> ,"attribute_errors":[
+_MISSING_OPEN_QUOTE_RE = re.compile(
+ r'([\]}"0-9]\s*,\s*|{\s*)([a-zA-Z_][a-zA-Z0-9_]*")'
+)
+
+
+# ---------------------------------------------------------------------------
+# JSON response parsing
+# ---------------------------------------------------------------------------
+
+def _dedupe_keys(pairs):
+ """JSON object_pairs_hook that keeps the first occurrence of duplicate keys."""
+ result = {}
+ for key, value in pairs:
+ if key not in result:
+ result[key] = value
+ return result
+
+
+def _loads(text):
+ """json.loads with duplicate-key handling (keeps first occurrence)."""
+ return json.loads(text, object_pairs_hook=_dedupe_keys)
+
+
+def _is_structural_quote(text, i):
+ """Check if the quote at position *i* is likely a JSON structural delimiter.
+
+ Returns True for quotes that delimit JSON keys/values (after ``:``, before
+ ``:`` etc.) and False for quotes that are part of prose inside a string
+ value (e.g. the model quoting a phrase in an explanation).
+ """
+ before = text[i - 1] if i > 0 else ''
+ after = text[i + 1] if i + 1 < len(text) else ''
+ # Quotes right after : [ { are always structural (opening a value/key)
+ if before in (':', '[', '{'):
+ return True
+ # Quotes right before : ] } are always structural (closing a key/value)
+ if after in (':', ']', '}'):
+ return True
+ # For comma-adjacent quotes, check deeper context.
+ if before == ',':
+ pre_comma = text[i - 2] if i >= 2 else ''
+ return pre_comma in ('"', ']', '}') or pre_comma.isdigit()
+ if after == ',':
+ post_comma = text[i + 2] if i + 2 < len(text) else ''
+ return post_comma in ('"', '[', '{') or post_comma.isdigit()
+ return False
+
+
+def _fix_unescaped_quotes(text, max_attempts=50):
+ """Iteratively escape unescaped double-quotes that break JSON parsing.
+
+ When the model uses literal " inside a JSON string value (e.g. to quote
+ a phrase in an explanation), the JSON parser sees it as a string
+ terminator. This function locates each offending quote by parsing,
+ catching the error position, searching backwards for a non-structural
+ quote, escaping it, and retrying.
+ """
+ for _ in range(max_attempts):
+ try:
+ return _loads(text)
+ except json.JSONDecodeError as e:
+ pos = e.pos
+ search_start = max(0, pos - 5)
+ found = False
+ for i in range(pos, search_start - 1, -1):
+ if i < len(text) and text[i] == '"' and (i == 0 or text[i - 1] != '\\'):
+ if _is_structural_quote(text, i):
+ continue
+ text = text[:i] + '\\"' + text[i + 1:]
+ found = True
+ break
+ if not found:
+ raise
+ raise json.JSONDecodeError("Max quote-fix attempts reached", text, 0)
+
+
+def _fix_orphan_keys(text):
+ """Remove bare string values inside JSON objects (orphan keys with no value).
+
+ The model sometimes hallucinates partial key fragments (e.g. ``"pred_old"``)
+ that appear as bare values in objects, producing invalid JSON like::
+
+ {"ref_id":"R3","pred_old","pred_id":"P2"}
+
+ This function strips them so the JSON becomes parsable.
+ """
+ text = _ORPHAN_KEY_AFTER_COMMA_RE.sub('', text)
+ text = _ORPHAN_KEY_AT_START_RE.sub('', text)
+ return text
+
+
+def parse_json_response(response, batch_idx=None):
+ """Parse model response as JSON, applying progressive fixes.
+
+ Fix pipeline (each step only attempted if prior steps fail):
+ 1. Raw parse
+ 2. Remove orphan keys (bare strings in object context)
+ 3. Fix invalid backslash escapes (``\\L``, ``\\_``, etc.)
+ 4. Fix unescaped double-quotes inside string values
+
+ Args:
+ response: Raw model output string.
+ batch_idx: Optional index for error context in batch evaluation.
+
+ Returns:
+ Parsed JSON object (dict).
+
+ Raises:
+ ValueError: If the response cannot be parsed as JSON.
+ """
+ # 0. Pre-parse text fixes
+    # Normalize smart quotes: escape double ones (they appear inside JSON
+    # string values); plain ' is valid unescaped in JSON, so no escape needed.
+    response = response.replace('\u201c', '\\"').replace('\u201d', '\\"')
+    response = response.replace('\u2018', "'").replace('\u2019', "'")
+ # Fix malformed array-object hybrids: ["":["R1"] -> ["R1"]
+ response = _MALFORMED_ARRAY_OBJ_RE.sub('[', response)
+ # Fix missing opening quote on keys: ,attribute_errors" -> ,"attribute_errors"
+ response = _MISSING_OPEN_QUOTE_RE.sub(r'\1"\2', response)
+ # 1. Raw
+ try:
+ return _loads(response)
+ except json.JSONDecodeError:
+ pass
+ # 2. Orphan keys (bare strings in object context, e.g. "pred_old")
+ deorphaned = _fix_orphan_keys(response)
+ try:
+ return _loads(deorphaned)
+ except json.JSONDecodeError:
+ pass
+ # 3. Invalid backslash escapes
+ escaped = _INVALID_ESCAPE_RE.sub(r'\\\\', response)
+ try:
+ return _loads(escaped)
+ except json.JSONDecodeError:
+ pass
+ # 4. Unescaped quotes (try on both raw and escape-fixed variants)
+ for text in (response, escaped):
+ try:
+ return _fix_unescaped_quotes(text)
+ except (json.JSONDecodeError, ValueError):
+ pass
+ ctx = f" for batch item {batch_idx}" if batch_idx is not None else ""
+ raise ValueError(
+ f"Failed to parse model response as JSON{ctx}\n"
+ f"Response ({len(response)} chars):\n{response}"
+ )