Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions openverifiablellm/eval/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .perplexity import PerplexityEvaluator

__all__ = [
"PerplexityEvaluator",
]
44 changes: 44 additions & 0 deletions openverifiablellm/eval/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
from typing import List

try:
from typing import Protocol, runtime_checkable
except ImportError: # Python < 3.8
from typing_extensions import Protocol, runtime_checkable


@runtime_checkable
class Model(Protocol):
    """Duck-typed interface for a language-model callable.

    Any object that maps a sequence of token IDs to a 2-D sequence of
    logits (one row per input position) satisfies this protocol.
    """

    def __call__(self, input_ids: List[int]) -> List[List[float]]:
        ...


@runtime_checkable
class Tokenizer(Protocol):
    """Duck-typed interface for a tokenizer.

    Any object exposing ``encode(text) -> list[int]`` satisfies this
    protocol.
    """

    def encode(self, text: str) -> List[int]:
        ...


class BaseEvaluator(ABC):
    """Common interface implemented by every dataset evaluator."""

    @abstractmethod
    def evaluate(self, model: Model, tokenizer: Tokenizer) -> dict:
        """
        Run the benchmark for *model* using *tokenizer*.

        Parameters
        ----------
        model : callable
            Accepts a sequence of token IDs and returns a 2-D sequence of
            logits with shape ``(len(input_ids), vocab_size)``.
        tokenizer : object
            Provides an ``encode(text: str) -> list[int]`` method.

        Returns
        -------
        dict
            Benchmark-specific evaluation results.
        """
5 changes: 5 additions & 0 deletions openverifiablellm/eval/factual/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .factual_consistency import WikipediaFactualEvaluator

__all__ = [
"WikipediaFactualEvaluator",
]
246 changes: 246 additions & 0 deletions openverifiablellm/eval/factual/factual_consistency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
"""
openverifiablellm/eval/factual/factual_consistency.py

Wikipedia-based factual consistency evaluator.
"""

import math
import random
import re
from pathlib import Path
from typing import List, Optional, Union

from ..base import BaseEvaluator
from ..perplexity import PerplexityEvaluator

# Crude named-entity matcher: one or more capitalized words in sequence
# (e.g. "New York"). Intentionally simple; no NLP dependency required.
_ENTITY_RE = re.compile(r"\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\b")


class WikipediaFactualEvaluator(BaseEvaluator):
"""
Evaluates factual consistency of a language model using Wikipedia passages.

For each sentence extracted from a processed Wikipedia text file
(``wiki_clean.txt``), a counterfactual variant is generated by substituting
a named entity found in the sentence with a different named entity drawn
from the same passage. The model's perplexity is then compared on the
original (factual) vs the substituted (counterfactual) sentence. A
well-trained model should assign lower perplexity to factual sentences.

The ``factual_score`` is the mean per-pair difference
``(counterfactual_ppl - factual_ppl)``: positive values indicate the model
correctly prefers factual sentences, negative values indicate the model
prefers counterfactual sentences.

Named entities are identified with the simple capitalized-word-sequence
regex ``r"\\b([A-Z][a-z]+(?:\\s+[A-Z][a-z]+)*)\\b"``. Evaluation is
fully deterministic: ``random.seed(42)`` is applied inside
:meth:`evaluate` before any entity selection.

Parameters
----------
wiki_text_path : str or Path
Path to the processed ``wiki_clean.txt`` file produced by
:func:`openverifiablellm.utils.extract_text_from_xml`.
n_samples : int or None
Maximum number of sentence pairs to evaluate. ``None`` evaluates
all available pairs. Default ``None``.
"""

def __init__(
self,
wiki_text_path: Union[str, Path],
n_samples: Optional[int] = None,
):
self.wiki_text_path = Path(wiki_text_path)
self.n_samples = n_samples

# ------------------------------------------------------------------
# Static helpers
# ------------------------------------------------------------------

@staticmethod
def _substitute_entity(sentence: str, candidate_entities: List[str]) -> Optional[str]:
    """
    Swap the first named entity in *sentence* for a randomly chosen
    different entity from *candidate_entities*.

    Entities are capitalized word sequences matched by
    ``r"\\b([A-Z][a-z]+(?:\\s+[A-Z][a-z]+)*)\\b"``.

    Parameters
    ----------
    sentence : str
        Sentence to perturb.
    candidate_entities : list[str]
        Pool of entities to draw substitutes from (typically every
        entity found in the enclosing passage).

    Returns
    -------
    str or None
        The sentence with its first entity replaced, or ``None`` when
        the sentence contains no entity or the pool offers no distinct
        substitute.
    """
    entities = _ENTITY_RE.findall(sentence)
    if not entities:
        return None

    target = entities[0]
    pool = [entity for entity in candidate_entities if entity != target]
    if not pool:
        return None

    replacement = random.choice(pool)
    # Whole-word boundaries so e.g. "Rome" never matches inside "Romeo".
    return re.sub(r"\b" + re.escape(target) + r"\b", replacement, sentence, count=1)

@staticmethod
def _extract_passages(
    wiki_text_path: Union[str, Path],
    n_samples: Optional[int],
) -> List[dict]:
    """
    Build factual/counterfactual sentence pairs from *wiki_text_path*.

    The file is read line by line; consecutive non-empty lines form a
    passage (blank lines separate passages). Each passage is joined into
    one string, split on ``". "``, and every resulting sentence is fed to
    :meth:`_substitute_entity`. A pair is emitted for each sentence that
    yields a valid counterfactual. Collection stops early once
    *n_samples* pairs exist (when *n_samples* is not ``None``).

    Parameters
    ----------
    wiki_text_path : str or Path
        Path to the processed ``wiki_clean.txt`` file.
    n_samples : int or None
        Maximum number of pairs to return.

    Returns
    -------
    list[dict]
        Each element is ``{"original": str, "counterfactual": str}``.
    """
    collected: List[dict] = []

    def _cap_reached() -> bool:
        return n_samples is not None and len(collected) >= n_samples

    def _harvest(passage_lines: List[str]) -> None:
        text = " ".join(passage_lines)
        entities = _ENTITY_RE.findall(text)
        if not entities:
            return
        for raw_sentence in text.split(". "):
            if _cap_reached():
                return
            sentence = raw_sentence.strip()
            if not sentence:
                continue
            variant = WikipediaFactualEvaluator._substitute_entity(
                sentence, entities
            )
            if variant is not None and variant != sentence:
                collected.append(
                    {"original": sentence, "counterfactual": variant}
                )

    buffer: List[str] = []
    with open(wiki_text_path, encoding="utf-8") as handle:
        for raw in handle:
            stripped = raw.rstrip("\n").strip()
            if stripped:
                buffer.append(stripped)
                continue
            # Blank line: close out the current passage, if any.
            if buffer:
                _harvest(buffer)
                buffer = []
            if _cap_reached():
                return collected

    # File may end without a trailing blank line.
    if buffer:
        _harvest(buffer)

    return collected

# ------------------------------------------------------------------
# BaseEvaluator interface
# ------------------------------------------------------------------

def evaluate(self, model, tokenizer) -> dict:
    """
    Compute factual consistency scores for *model*.

    Extracts sentence pairs from the configured Wikipedia text file, then
    computes perplexity for each original and counterfactual sentence using
    the same teacher-forced method as
    :class:`~openverifiablellm.eval.perplexity.PerplexityEvaluator`.

    The global :mod:`random` generator is seeded with ``42`` for the
    duration of pair extraction (so results are fully reproducible) and its
    previous state is restored afterwards, so calling this method never
    perturbs process-wide randomness used elsewhere.

    Parameters
    ----------
    model : callable
        ``model(input_ids) -> 2-D sequence`` of shape
        ``(len(input_ids), vocab_size)``, as described in
        :meth:`~openverifiablellm.eval.perplexity.PerplexityEvaluator.compute_sentence_perplexity`.
    tokenizer : object
        Object with ``encode(text: str) -> list[int]``.

    Returns
    -------
    dict
        A dictionary with the following keys:

        * **factual_perplexity** (*float*) — mean perplexity on original
          sentences.
        * **counterfactual_perplexity** (*float*) — mean perplexity on
          counterfactual sentences.
        * **factual_score** (*float*) — mean per-pair difference
          ``(counterfactual_ppl - factual_ppl)``; positive means the model
          correctly prefers factual sentences (good), negative means the
          model prefers counterfactual sentences (bad).
    """
    # Seed deterministically for reproducible entity substitution, but save
    # and restore the caller's global RNG state so this evaluator does not
    # clobber unrelated randomness in the host process.
    saved_rng_state = random.getstate()
    random.seed(42)
    try:
        pairs = self._extract_passages(self.wiki_text_path, self.n_samples)
    finally:
        random.setstate(saved_rng_state)

    if not pairs:
        # No usable sentence pairs: report infinities so callers can
        # distinguish "no data" from a genuine score.
        return {
            "factual_perplexity": float("inf"),
            "counterfactual_perplexity": float("inf"),
            "factual_score": float("inf"),
        }

    factual_ppls: List[float] = []
    counterfactual_ppls: List[float] = []
    score_diffs: List[float] = []

    for pair in pairs:
        factual_tokens = tokenizer.encode(pair["original"])
        cf_tokens = tokenizer.encode(pair["counterfactual"])

        factual_ppl = PerplexityEvaluator.compute_sentence_perplexity(
            model, factual_tokens
        )
        cf_ppl = PerplexityEvaluator.compute_sentence_perplexity(model, cf_tokens)

        # Skip pairs where either perplexity is inf/nan so one degenerate
        # sentence cannot poison the aggregate means.
        if not math.isfinite(factual_ppl) or not math.isfinite(cf_ppl):
            continue

        factual_ppls.append(factual_ppl)
        counterfactual_ppls.append(cf_ppl)
        score_diffs.append(cf_ppl - factual_ppl)

    n = len(factual_ppls)
    if n == 0:
        # Pairs existed but every perplexity was non-finite.
        return {
            "factual_perplexity": float("nan"),
            "counterfactual_perplexity": float("nan"),
            "factual_score": float("nan"),
        }
    return {
        "factual_perplexity": sum(factual_ppls) / n,
        "counterfactual_perplexity": sum(counterfactual_ppls) / n,
        "factual_score": sum(score_diffs) / n,
    }
Loading
Loading