diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py
index e9407b53..c65003e7 100644
--- a/mellea/backends/huggingface.py
+++ b/mellea/backends/huggingface.py
@@ -251,6 +251,11 @@ async def _generate_from_intrinsic(
         if not ctx.is_chat_context:
             raise Exception("Does not yet support non-chat contexts.")
 
+        if len(model_options.items()) > 0:
+            FancyLogger.get_logger().info(
+                "passing in model options when generating with an adapter; some model options may be overwritten / ignored"
+            )
+
         linearized_ctx = ctx.view_for_generation()
         assert linearized_ctx is not None, (
             "If ctx.is_chat_context, then the context should be linearizable."
@@ -311,6 +316,12 @@ async def _generate_from_intrinsic(
             "messages": conversation,
             "extra_body": {"documents": docs},
         }
+
+        # Convert other parameters from Mellea proprietary format to standard format.
+        for model_option in model_options:
+            if model_option == ModelOption.TEMPERATURE:
+                request_json["temperature"] = model_options[model_option]
+
         rewritten = rewriter.transform(request_json, **action.intrinsic_kwargs)
 
         # TODO: Handle caching here. granite_common doesn't tell us what changed,
diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py
index 7d1c70ab..ba825753 100644
--- a/mellea/backends/openai.py
+++ b/mellea/backends/openai.py
@@ -435,6 +435,12 @@ async def _generate_from_intrinsic(
             "extra_body": {"documents": docs},
         }
 
+        # Convert other parameters from Mellea proprietary format to standard format.
+        if model_options is not None:
+            for model_option in model_options:
+                if model_option == ModelOption.TEMPERATURE:
+                    request_json["temperature"] = model_options[model_option]
+
         rewritten = rewriter.transform(request_json, **action.intrinsic_kwargs)
 
         self.load_adapter(adapter.qualified_name)
diff --git a/mellea/stdlib/intrinsics/rag.py b/mellea/stdlib/intrinsics/rag.py
index 5885c900..80f7a999 100644
--- a/mellea/stdlib/intrinsics/rag.py
+++ b/mellea/stdlib/intrinsics/rag.py
@@ -9,6 +9,7 @@
     AdapterType,
     GraniteCommonAdapter,
 )
+from mellea.backends.types import ModelOption
 from mellea.stdlib.base import ChatContext, Document
 from mellea.stdlib.chat import Message
 from mellea.stdlib.intrinsics.intrinsic import Intrinsic
@@ -63,6 +64,7 @@ def _call_intrinsic(
         intrinsic,
         context,
         backend,
+        model_options={ModelOption.TEMPERATURE: 0.0},  # No rejection sampling, please
         strategy=None,
     )
 
@@ -277,7 +279,7 @@ def rewrite_answer_for_relevance(
         backend,
         kwargs={
             "answer_relevance_category": result_json["answer_relevance_category"],
-            "answer_relevance_analysis": result_json["answer_relevance_category"],
+            "answer_relevance_analysis": result_json["answer_relevance_analysis"],
             "correction_method": correction_method,
         },
     )
diff --git a/test/stdlib_intrinsics/test_rag/test_rag.py b/test/stdlib_intrinsics/test_rag/test_rag.py
index e2477cf3..016b78a0 100644
--- a/test/stdlib_intrinsics/test_rag/test_rag.py
+++ b/test/stdlib_intrinsics/test_rag/test_rag.py
@@ -161,7 +161,13 @@ def test_hallucination_detection(backend):
 def test_answer_relevance(backend):
     """Verify that the answer relevance composite intrinsic functions properly."""
     context, answer, docs = _read_input_json("answer_relevance.json")
-    expected_rewrite = "Alice, Bob, and Carol attended the meeting."
+
+    # Note that this is not the optimal answer. This test is currently using an
+    # outdated LoRA adapter. Releases of new adapters will come after the Mellea
+    # integration has stabilized.
+    expected_rewrite = (
+        "The documents do not provide information about the attendees of the meeting."
+    )
 
     # First call triggers adapter loading
     result = rag.rewrite_answer_for_relevance(answer, docs, context, backend)
@@ -178,5 +184,17 @@ def test_answer_relevance(backend):
     assert result == answer
 
 
+def test_answer_relevance_classifier(backend):
+    """Verify that the first phase of the answer relevance flow behaves as expected."""
+    context, answer, docs = _read_input_json("answer_relevance.json")
+
+    result_json = rag._call_intrinsic(
+        "answer_relevance_classifier",
+        context.add(Message("assistant", answer, documents=list(docs))),
+        backend,
+    )
+    assert result_json["answer_relevance_likelihood"] == 0.0
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
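
Reviewer note (not part of the patch): the sketch below illustrates the option-translation behavior this diff adds to both backends. The helper name apply_model_options and the literal request dict are hypothetical, introduced only for illustration; the ModelOption import and the conversion loop mirror the patched openai.py code.

    from mellea.backends.types import ModelOption

    def apply_model_options(request_json: dict, model_options: dict | None) -> dict:
        """Hypothetical mirror of the conversion loop added in openai.py: copy
        recognized Mellea-proprietary option keys into the standard request
        format, leaving unrecognized keys untouched."""
        if model_options is not None:
            for model_option in model_options:
                if model_option == ModelOption.TEMPERATURE:
                    request_json["temperature"] = model_options[model_option]
        return request_json

    # rag._call_intrinsic now pins temperature to 0.0 (greedy decoding, i.e.
    # "no rejection sampling"), so the request handed to rewriter.transform
    # carries a standard "temperature" field:
    request = apply_model_options(
        {"messages": [], "extra_body": {"documents": []}},
        {ModelOption.TEMPERATURE: 0.0},
    )
    assert request["temperature"] == 0.0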