diff --git a/llm_fal_ai/models/fal_ai_provider.py b/llm_fal_ai/models/fal_ai_provider.py index 038e8523..bf5645f3 100644 --- a/llm_fal_ai/models/fal_ai_provider.py +++ b/llm_fal_ai/models/fal_ai_provider.py @@ -72,6 +72,8 @@ def fal_ai_generate(self, input_data, model=None, stream=False, **kwargs): input_data = ( json.loads(input_data) if isinstance(input_data, str) else input_data ) + input_data = self._fal_ai_resolve_inputs(input_data, model) + self._fal_ai_validate_inputs(input_data, model) try: if stream: @@ -145,6 +147,101 @@ def _fal_ai_generate_stream(self, client, model_name, input_data): _logger.error(f"Error in FAL AI stream: {e}") raise UserError(_(f"FAL AI streaming failed: {str(e)}")) from e + # Maps generic input key names (used by the AI assistant) to the fal.ai schema + # field variants tried in order. First match in the model's input schema wins. + _FAL_INPUT_FIELD_ALIASES = { + "image": ["image_url", "image"], + "audio": ["audio_url", "ref_audio_url", "audio"], + "video": ["video_url", "video"], + "mask": ["mask_url", "mask"], + # Voice cloning / TTS models use ref_audio_url for the reference sample. + "voice_sample": ["ref_audio_url", "audio_url", "voice_sample"], + "ref_audio": ["ref_audio_url", "audio_url", "ref_audio"], + } + + def _fal_ai_resolve_inputs(self, inputs, model): + """Resolve attachment IDs → data URIs and map generic keys to schema field names. + + The AI assistant passes inputs like {"image": }. This method: + 1. Looks up the ir.attachment record for any integer value. + 2. Converts it to a data URI (data:;base64,...). + 3. Renames the key to match the field name declared in the model's input schema + (e.g. "image" → "image_url" when "image_url" is in the schema). + """ + if not inputs or not isinstance(inputs, dict): + return inputs + + input_schema = self._fal_ai_get_input_schema(model) + schema_props = (input_schema or {}).get("properties", {}) + + resolved = {} + for key, value in inputs.items(): + # Resolve the target field name via schema aliases. + target_key = key + if key in self._FAL_INPUT_FIELD_ALIASES: + for candidate in self._FAL_INPUT_FIELD_ALIASES[key]: + if candidate in schema_props: + target_key = candidate + break + + # Resolve attachment references → data URI. + # Handles: integer 569, string "569", or string "attachment:569". + att_id = None + if isinstance(value, int) and value > 0: + att_id = value + elif isinstance(value, str): + raw = value.split(":", 1)[-1] if value.startswith("attachment:") else value + if raw.isdigit(): + att_id = int(raw) + + if att_id: + att = self.env["ir.attachment"].sudo().browse(att_id) + if att.exists() and att.datas: + mimetype = att.mimetype or "application/octet-stream" + resolved[target_key] = f"data:{mimetype};base64,{att.datas.decode()}" + _logger.info( + "fal_ai: resolved attachment %s (%s) → %s", + att_id, mimetype, target_key, + ) + continue + + resolved[target_key] = value + + return resolved + + def _fal_ai_get_input_schema(self, model): + """Return the normalized input schema, falling back to OpenAPI refs.""" + if not model: + return {} + + details = model.details or {} + input_schema = details.get("input_schema") + if input_schema: + return input_schema + + openapi_schema = details.get("openapi") + if not openapi_schema: + return {} + + extracted_input_schema, _output_schema = self._fal_ai_extract_openapi_io_schemas( + openapi_schema + ) + return extracted_input_schema or {} + + def _fal_ai_validate_inputs(self, inputs, model): + """Fail fast for missing required fal.ai fields before making the API call.""" + if not inputs or not isinstance(inputs, dict): + raise UserError(_("Generation inputs are required.")) + + input_schema = self._fal_ai_get_input_schema(model) + required = input_schema.get("required", []) if input_schema else [] + missing = [field for field in required if inputs.get(field) in (None, "")] + if missing: + raise UserError( + _("Missing required generation input(s): %s") + % ", ".join(sorted(missing)) + ) + def fal_ai_models(self, model_id=None): """Retrieve available Fal model endpoints from the platform API.""" self.ensure_one() @@ -197,8 +294,7 @@ def _fal_ai_fetch_models(self, model_id=None): _("Fal.ai model fetch failed: %s") % (e.reason or str(e)) ) from e - for raw_model in payload.get("models", []): - yield raw_model + yield from payload.get("models", []) if model_id: break @@ -224,15 +320,8 @@ def _fal_ai_parse_model(self, raw_model): openapi_schema = raw_model.get("openapi") if openapi_schema: - input_schema = ( - openapi_schema.get("components", {}) - .get("schemas", {}) - .get("Input") - ) - output_schema = ( - openapi_schema.get("components", {}) - .get("schemas", {}) - .get("Output") + input_schema, output_schema = self._fal_ai_extract_openapi_io_schemas( + openapi_schema ) if input_schema: details["input_schema"] = input_schema @@ -245,6 +334,66 @@ def _fal_ai_parse_model(self, raw_model): "details": details, } + def _fal_ai_extract_openapi_io_schemas(self, openapi_schema): + """Extract fal.ai input/output schemas from OpenAPI, including named refs.""" + schemas = openapi_schema.get("components", {}).get("schemas", {}) + input_schema = schemas.get("Input") + output_schema = schemas.get("Output") + + input_ref = self._fal_ai_find_request_body_ref(openapi_schema) + if input_ref: + input_schema = self._fal_ai_resolve_schema_ref(openapi_schema, input_ref) + + output_ref = self._fal_ai_find_success_response_ref(openapi_schema) + if output_ref: + output_schema = self._fal_ai_resolve_schema_ref(openapi_schema, output_ref) + + return input_schema, output_schema + + def _fal_ai_find_request_body_ref(self, openapi_schema): + for path_data in openapi_schema.get("paths", {}).values(): + for operation in path_data.values(): + if not isinstance(operation, dict): + continue + schema = ( + operation.get("requestBody", {}) + .get("content", {}) + .get("application/json", {}) + .get("schema", {}) + ) + ref = schema.get("$ref") + if ref: + return ref + return None + + def _fal_ai_find_success_response_ref(self, openapi_schema): + for path_data in openapi_schema.get("paths", {}).values(): + for operation in path_data.values(): + if not isinstance(operation, dict): + continue + responses = operation.get("responses", {}) + for status in ("200", 200): + schema = ( + responses.get(status, {}) + .get("content", {}) + .get("application/json", {}) + .get("schema", {}) + ) + ref = schema.get("$ref") + if ref and not ref.endswith("/QueueStatus"): + return ref + return None + + def _fal_ai_resolve_schema_ref(self, openapi_schema, ref): + prefix = "#/components/schemas/" + if not isinstance(ref, str) or not ref.startswith(prefix): + return None + return ( + openapi_schema.get("components", {}) + .get("schemas", {}) + .get(ref[len(prefix):]) + ) + def _fal_ai_capabilities_from_category(self, category, endpoint_id): """Map Fal model categories to Odoo provider capabilities.""" name = endpoint_id.lower() diff --git a/llm_fal_ai/tests/test_fal_ai_provider.py b/llm_fal_ai/tests/test_fal_ai_provider.py index c02daea4..f83a07ba 100644 --- a/llm_fal_ai/tests/test_fal_ai_provider.py +++ b/llm_fal_ai/tests/test_fal_ai_provider.py @@ -1,3 +1,4 @@ +from odoo.exceptions import UserError from odoo.tests.common import TransactionCase @@ -48,3 +49,147 @@ def test_parse_model_extracts_openapi_schemas(self): self.assertEqual(parsed["details"]["category"], "text-to-image") self.assertIn("input_schema", parsed["details"]) self.assertIn("output_schema", parsed["details"]) + + def test_parse_model_extracts_named_openapi_refs(self): + raw_model = { + "endpoint_id": "fal-ai/bytedance/seedance/v1/pro/image-to-video", + "metadata": { + "category": "image-to-video", + "description": "Image to video", + }, + "openapi": { + "paths": { + "/model": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SeedanceInput" + } + } + }, + "required": True, + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/QueueStatus" + } + } + } + } + }, + } + }, + "/model/requests/{request_id}": { + "get": { + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SeedanceOutput" + } + } + } + } + } + } + }, + }, + "components": { + "schemas": { + "QueueStatus": {"type": "object"}, + "SeedanceInput": { + "type": "object", + "required": ["prompt", "image_url"], + "properties": { + "prompt": {"type": "string"}, + "image_url": {"type": "string"}, + }, + }, + "SeedanceOutput": { + "type": "object", + "properties": {"video": {"type": "object"}}, + }, + } + }, + }, + } + + parsed = self.provider._fal_ai_parse_model(raw_model) + + input_schema = parsed["details"]["input_schema"] + output_schema = parsed["details"]["output_schema"] + self.assertEqual(input_schema["required"], ["prompt", "image_url"]) + self.assertIn("image_url", input_schema["properties"]) + self.assertIn("video", output_schema["properties"]) + + def test_resolve_inputs_uses_openapi_schema_fallback(self): + model = self.env["llm.model"].create({ + "name": "fal-ai/bytedance/seedance/v1/pro/image-to-video", + "provider_id": self.provider.id, + "model_use": "generation", + "details": { + "openapi": { + "paths": { + "/model": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SeedanceInput" + } + } + } + } + } + } + }, + "components": { + "schemas": { + "SeedanceInput": { + "type": "object", + "required": ["prompt", "image_url"], + "properties": { + "prompt": {"type": "string"}, + "image_url": {"type": "string"}, + }, + } + } + }, + } + }, + }) + + resolved = self.provider._fal_ai_resolve_inputs( + {"image": "https://example.com/image.jpg", "prompt": "make it move"}, + model, + ) + + self.assertEqual(resolved["image_url"], "https://example.com/image.jpg") + self.assertNotIn("image", resolved) + + def test_validate_inputs_fails_before_fal_api_when_prompt_missing(self): + model = self.env["llm.model"].create({ + "name": "fal-ai/bytedance/seedance/v1/pro/image-to-video", + "provider_id": self.provider.id, + "model_use": "generation", + "details": { + "input_schema": { + "type": "object", + "required": ["prompt", "image_url"], + "properties": { + "prompt": {"type": "string"}, + "image_url": {"type": "string"}, + }, + } + }, + }) + + with self.assertRaisesRegex(UserError, "prompt"): + self.provider._fal_ai_validate_inputs({"image_url": "x"}, model)