From d3d5680c8ba2e0ace1045fc8d89666f6caec4145 Mon Sep 17 00:00:00 2001 From: MohammedSB Date: Sat, 4 Apr 2026 01:55:57 -0400 Subject: [PATCH] Update CRIMSON metric to align with up-to-date CRIMSON implementation - Match HF model init: padding_side="left", generation_config detection, flash attention 2 support, repetition_penalty=1.1 - Auto-disable guidelines for default MedGemma HF model (trained without them) - Port robust JSON parsing from standalone (handles orphan keys, bad escapes, unescaped quotes, smart quotes, malformed arrays) - Fix scoring fallback defaults to match standalone (0.25 for unknown significance instead of 0.0) - Update default HF model to rajpurkarlab/medgemma-4b-it-crimson - Add cache_dir parameter for controlling HF model cache location - Add CRIMSON to README acknowledgments I also saw that you added the repetition_penalty=1.1, which is good for helping the model not get stuck in the self-repetitive loop, which it sometimes does. This is especially true if the inputs are not formatted as expected. Do you have any tips on how to improve fine-tuning to avoid this? 
--- README.md | 2 +- RadEval/RadEval.py | 3 + RadEval/metrics/crimson/crimson.py | 52 ++++++--- RadEval/metrics/crimson/utils.py | 180 +++++++++++++++++++++++++++++ 4 files changed, 222 insertions(+), 15 deletions(-) create mode 100644 RadEval/metrics/crimson/utils.py diff --git a/README.md b/README.md index c7b71d3..cfd626d 100755 --- a/README.md +++ b/README.md @@ -234,7 +234,7 @@ A curated evaluation set annotated by board-certified radiologists for validatin ## Acknowledgments -Built on the work of the radiology AI community: [CheXbert](https://github.com/stanfordmlgroup/CheXbert), [RadGraph](https://github.com/jbdel/RadGraph), [BERTScore](https://github.com/Tiiiger/bert_score), [RaTEScore](https://github.com/MAGIC-AI4Med/RaTEScore), [SRR-BERT](https://github.com/StanfordAIMI/SRR-BERT), [GREEN](https://github.com/Stanford-AIMI/GREEN), and datasets like [MIMIC-CXR](https://physionet.org/content/mimic-cxr/2.0.0/). +Built on the work of the radiology AI community: [CheXbert](https://github.com/stanfordmlgroup/CheXbert), [RadGraph](https://github.com/jbdel/RadGraph), [BERTScore](https://github.com/Tiiiger/bert_score), [RaTEScore](https://github.com/MAGIC-AI4Med/RaTEScore), [SRR-BERT](https://github.com/StanfordAIMI/SRR-BERT), [GREEN](https://github.com/Stanford-AIMI/GREEN), [CRIMSON](https://github.com/rajpurkarlab/CRIMSON), and datasets like [MIMIC-CXR](https://physionet.org/content/mimic-cxr/2.0.0/). ---
diff --git a/RadEval/RadEval.py b/RadEval/RadEval.py index 2d2eb47..1bab4dc 100755 --- a/RadEval/RadEval.py +++ b/RadEval/RadEval.py @@ -38,6 +38,7 @@ def __init__(self, crimson_model=None, crimson_batch_size=1, crimson_max_concurrent=50, + cache_dir=None, hoppr_crimson_ct_api="openai", hoppr_crimson_ct_model=None, do_per_sample=False, @@ -76,6 +77,7 @@ def __init__(self, self.crimson_model = crimson_model self.crimson_batch_size = crimson_batch_size self.crimson_max_concurrent = crimson_max_concurrent + self.cache_dir = cache_dir self.hoppr_crimson_ct_api = hoppr_crimson_ct_api self.hoppr_crimson_ct_model = hoppr_crimson_ct_model self.do_radeval_bertscore = do_radeval_bertscore @@ -184,6 +186,7 @@ def __init__(self, gemini_api_key=self.gemini_api_key, batch_size=self.crimson_batch_size, max_concurrent=self.crimson_max_concurrent, + cache_dir=self.cache_dir, ) except (ImportError, EnvironmentError, OSError) as e: warnings.warn( diff --git a/RadEval/metrics/crimson/crimson.py b/RadEval/metrics/crimson/crimson.py index ea76dc7..aa735ce 100644 --- a/RadEval/metrics/crimson/crimson.py +++ b/RadEval/metrics/crimson/crimson.py @@ -21,6 +21,7 @@ from .._llm_base import LLMMetricBase from .prompt_parts import build_prompt as _build_evaluation_prompt_fn +from .utils import parse_json_response as _parse_json_response_robust logger = logging.getLogger(__name__) @@ -113,8 +114,9 @@ class CRIMSONScore(LLMMetricBase): SUPPORTED_PROVIDERS: ClassVar[set[str]] = {"openai", "hf"} - DEFAULT_HF_MODEL = "CRIMSONScore/medgemma-4b-it-crimson" + DEFAULT_HF_MODEL = "rajpurkarlab/medgemma-4b-it-crimson" DEFAULT_OPENAI_MODEL = "gpt-5.2" + DEFAULT_MAX_NEW_TOKENS = 8192 def __init__( self, @@ -125,6 +127,7 @@ def __init__( device=None, batch_size=1, max_concurrent=50, + cache_dir=None, ): resolved_model = model_name or ( self.DEFAULT_OPENAI_MODEL if provider == "openai" @@ -140,6 +143,7 @@ def __init__( ) self.batch_size = batch_size + self.cache_dir = cache_dir if provider in ("huggingface", 
"hf"): self._init_hf_pipeline() @@ -159,6 +163,8 @@ def _init_hf_pipeline(self): logger.info("Using Flash Attention 2.") except ImportError: pass + if self.cache_dir: + model_kwargs["cache_dir"] = self.cache_dir self.pipe = transformers.pipeline( "text-generation", @@ -166,6 +172,12 @@ def _init_hf_pipeline(self): model_kwargs=model_kwargs, device_map="auto", ) + # Left-pad for efficient batch generation with decoder-only models + if self.pipe.tokenizer.padding_side != "left": + self.pipe.tokenizer.padding_side = "left" + # If the model has a generation_config.json, use it instead of + # injecting our own generation kwargs at inference time. + self._has_generation_config = not self.pipe.model.generation_config._from_model_config logger.info("Model loaded.") # ------------------------------------------------------------------ @@ -194,8 +206,8 @@ def _build_request(self, ref: str, hyp: str, **kwargs) -> dict[str, Any]: def _parse_response(self, raw: str) -> dict: cleaned = _extract_json_str(raw) try: - evaluation = json.loads(cleaned) - except json.JSONDecodeError: + evaluation = _parse_json_response_robust(cleaned) + except (ValueError, json.JSONDecodeError): repaired = _repair_truncated_json(cleaned) if repaired is not None: logger.warning("Repaired truncated CRIMSON JSON.") @@ -250,14 +262,23 @@ def _chat_completion(self, request: dict[str, Any]) -> str: def _hf_generate(self, prompt: str) -> str: messages = [ {"role": "system", "content": _SYSTEM_MSG}, - {"role": "user", "content": prompt + "\nPlease respond with valid JSON only."}, + {"role": "user", "content": prompt}, ] - outputs = self.pipe( - messages, - max_new_tokens=8192, - do_sample=False, - repetition_penalty=1.1, - ) + + if self._has_generation_config: + outputs = self.pipe( + messages, + generation_config=self.pipe.model.generation_config, + repetition_penalty=1.1, + ) + else: + outputs = self.pipe( + messages, + max_new_tokens=self.DEFAULT_MAX_NEW_TOKENS, + max_length=None, + do_sample=False, + 
repetition_penalty=1.1, + ) response = outputs[0]["generated_text"][-1]["content"] if not response: logger.warning("Empty response from HF pipeline. Raw: %s", outputs) @@ -385,6 +406,9 @@ def _nan_fallback(): def _build_evaluation_prompt(self, reference_findings, predicted_findings, patient_context=None, include_guidelines=True): + # For the MedGemmaCRIMSON, exclude guidelines. The model was trained without them. + if self.provider in ("huggingface", "hf") and self.model_name == self.DEFAULT_HF_MODEL: + include_guidelines = False return _build_evaluation_prompt_fn( reference_findings, predicted_findings, @@ -414,15 +438,15 @@ def _calculate_crimson(self, evaluation): def calculate_weighted_count(error_list, weights=significance_weights, key="clinical_significance"): - return sum(weights.get(error.get(key, "benign_expected"), 0.0) + return sum(weights.get(error.get(key, ""), 0.25) for error in error_list) ref_weight_by_id = { - ref["id"]: significance_weights.get(ref.get("clinical_significance", "benign_expected"), 0.0) + ref["id"]: significance_weights.get(ref.get("clinical_significance", ""), 0.25) for ref in reference_findings_list } pred_weight_by_id = { - pred["id"]: significance_weights.get(pred.get("clinical_significance", "benign_expected"), 0.0) + pred["id"]: significance_weights.get(pred.get("clinical_significance", ""), 0.25) for pred in predicted_findings_list } @@ -467,7 +491,7 @@ def calculate_weighted_count(error_list, weights=significance_weights, correct += base_weight else: sum_error_weights = sum( - attribute_severity_weights.get(err.get("severity", "negligible"), 0.0) + attribute_severity_weights.get(err.get("severity", ""), 0.25) for err in finding_attr_errors) denom = base_weight + sum_error_weights credit_factor = base_weight / denom if denom > 0 else 0.0 diff --git a/RadEval/metrics/crimson/utils.py b/RadEval/metrics/crimson/utils.py new file mode 100644 index 0000000..2460f50 --- /dev/null +++ b/RadEval/metrics/crimson/utils.py @@ -0,0 
+1,180 @@ +"""JSON parsing utilities for CRIMSON scoring. + +Ported from the standalone CRIMSON package to provide robust JSON parsing +for LLM responses that may contain malformed JSON. +""" + +import json +import re + +# --------------------------------------------------------------------------- +# Regex for fixing invalid JSON backslash escapes +# --------------------------------------------------------------------------- + +# Fix invalid JSON backslash escapes (e.g. \_ or \L) that the model may +# produce when echoing back input text. Valid JSON escapes +# (" \\ \/ \b \f \n \r \t \uXXXX) are left untouched. +_INVALID_ESCAPE_RE = re.compile(r'\\(?!["\\\//bfnrtu])') + +# --------------------------------------------------------------------------- +# Regex for removing orphan keys in JSON objects +# --------------------------------------------------------------------------- + +# Matches a bare string value inside a JSON object — a quoted string preceded +# by a comma and followed by a comma or closing brace, but NOT followed by a +# colon. Example: ,"pred_old", or ,"pred_old"} +# The leading comma is consumed so removal doesn't leave double commas. 
+_ORPHAN_KEY_AFTER_COMMA_RE = re.compile(r',\s*"[^"]*"\s*(?=\s*[,}])(?!\s*:)') +# Orphan as the first item in an object: {"orphan","key":"val"} +_ORPHAN_KEY_AT_START_RE = re.compile(r'(?<=\{)\s*"[^"]*"\s*,\s*(?=")') + +# Malformed array-object hybrid: ["":["R1"] -> ["R1"] +_MALFORMED_ARRAY_OBJ_RE = re.compile(r'\[""\s*:\s*\[') + +# Missing opening quote on a JSON key: ,attribute_errors":[ -> ,"attribute_errors":[ +_MISSING_OPEN_QUOTE_RE = re.compile( + r'([\]}"0-9]\s*,\s*|{\s*)([a-zA-Z_][a-zA-Z0-9_]*")' +) + + +# --------------------------------------------------------------------------- +# JSON response parsing +# --------------------------------------------------------------------------- + +def _dedupe_keys(pairs): + """JSON object_pairs_hook that keeps the first occurrence of duplicate keys.""" + result = {} + for key, value in pairs: + if key not in result: + result[key] = value + return result + + +def _loads(text): + """json.loads with duplicate-key handling (keeps first occurrence).""" + return json.loads(text, object_pairs_hook=_dedupe_keys) + + +def _is_structural_quote(text, i): + """Check if the quote at position *i* is likely a JSON structural delimiter. + + Returns True for quotes that delimit JSON keys/values (after ``:``, before + ``:`` etc.) and False for quotes that are part of prose inside a string + value (e.g. the model quoting a phrase in an explanation). + """ + before = text[i - 1] if i > 0 else '' + after = text[i + 1] if i + 1 < len(text) else '' + # Quotes right after : [ { are always structural (opening a value/key) + if before in (':', '[', '{'): + return True + # Quotes right before : ] } are always structural (closing a key/value) + if after in (':', ']', '}'): + return True + # For comma-adjacent quotes, check deeper context. 
+ if before == ',': + pre_comma = text[i - 2] if i >= 2 else '' + return pre_comma in ('"', ']', '}') or pre_comma.isdigit() + if after == ',': + post_comma = text[i + 2] if i + 2 < len(text) else '' + return post_comma in ('"', '[', '{') or post_comma.isdigit() + return False + + +def _fix_unescaped_quotes(text, max_attempts=50): + """Iteratively escape unescaped double-quotes that break JSON parsing. + + When the model uses literal " inside a JSON string value (e.g. to quote + a phrase in an explanation), the JSON parser sees it as a string + terminator. This function locates each offending quote by parsing, + catching the error position, searching backwards for a non-structural + quote, escaping it, and retrying. + """ + for _ in range(max_attempts): + try: + return _loads(text) + except json.JSONDecodeError as e: + pos = e.pos + search_start = max(0, pos - 5) + found = False + for i in range(pos, search_start - 1, -1): + if i < len(text) and text[i] == '"' and (i == 0 or text[i - 1] != '\\'): + if _is_structural_quote(text, i): + continue + text = text[:i] + '\\"' + text[i + 1:] + found = True + break + if not found: + raise + raise json.JSONDecodeError("Max quote-fix attempts reached", text, 0) + + +def _fix_orphan_keys(text): + """Remove bare string values inside JSON objects (orphan keys with no value). + + The model sometimes hallucinates partial key fragments (e.g. ``"pred_old"``) + that appear as bare values in objects, producing invalid JSON like:: + + {"ref_id":"R3","pred_old","pred_id":"P2"} + + This function strips them so the JSON becomes parsable. + """ + text = _ORPHAN_KEY_AFTER_COMMA_RE.sub('', text) + text = _ORPHAN_KEY_AT_START_RE.sub('', text) + return text + + +def parse_json_response(response, batch_idx=None): + """Parse model response as JSON, applying progressive fixes. + + Fix pipeline (each step only attempted if prior steps fail): + 1. Raw parse + 2. Remove orphan keys (bare strings in object context) + 3. 
Fix invalid backslash escapes (``\\L``, ``\\_``, etc.) +    4. Fix unescaped double-quotes inside string values + +    Args: +        response: Raw model output string. +        batch_idx: Optional index for error context in batch evaluation. + +    Returns: +        Parsed JSON object (dict). + +    Raises: +        ValueError: If the response cannot be parsed as JSON. +    """ +    # 0. Pre-parse text fixes +    # Escape smart double quotes; normalize smart singles — \' is not valid JSON +    response = response.replace('\u201c', '\\"').replace('\u201d', '\\"') +    response = response.replace('\u2018', "'").replace('\u2019', "'") +    # Fix malformed array-object hybrids: ["":["R1"] -> ["R1"] +    response = _MALFORMED_ARRAY_OBJ_RE.sub('[', response) +    # Fix missing opening quote on keys: ,attribute_errors" -> ,"attribute_errors" +    response = _MISSING_OPEN_QUOTE_RE.sub(r'\1"\2', response) +    # 1. Raw +    try: +        return _loads(response) +    except json.JSONDecodeError: +        pass +    # 2. Orphan keys (bare strings in object context, e.g. "pred_old") +    deorphaned = _fix_orphan_keys(response) +    try: +        return _loads(deorphaned) +    except json.JSONDecodeError: +        pass +    # 3. Invalid backslash escapes +    escaped = _INVALID_ESCAPE_RE.sub(r'\\\\', response) +    try: +        return _loads(escaped) +    except json.JSONDecodeError: +        pass +    # 4. Unescaped quotes (try on both raw and escape-fixed variants) +    for text in (response, escaped): +        try: +            return _fix_unescaped_quotes(text) +        except (json.JSONDecodeError, ValueError): +            pass +    ctx = f" for batch item {batch_idx}" if batch_idx is not None else "" +    raise ValueError( +        f"Failed to parse model response as JSON{ctx}\n" +        f"Response ({len(response)} chars):\n{response}" +    )