Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 160 additions & 11 deletions llm_fal_ai/models/fal_ai_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def fal_ai_generate(self, input_data, model=None, stream=False, **kwargs):
input_data = (
json.loads(input_data) if isinstance(input_data, str) else input_data
)
input_data = self._fal_ai_resolve_inputs(input_data, model)
self._fal_ai_validate_inputs(input_data, model)

try:
if stream:
Expand Down Expand Up @@ -145,6 +147,101 @@ def _fal_ai_generate_stream(self, client, model_name, input_data):
_logger.error(f"Error in FAL AI stream: {e}")
raise UserError(_(f"FAL AI streaming failed: {str(e)}")) from e

# Maps generic input key names (used by the AI assistant) to the fal.ai schema
# field variants tried in order. First match in the model's input schema wins.
_FAL_INPUT_FIELD_ALIASES = {
"image": ["image_url", "image"],
"audio": ["audio_url", "ref_audio_url", "audio"],
"video": ["video_url", "video"],
"mask": ["mask_url", "mask"],
# Voice cloning / TTS models use ref_audio_url for the reference sample.
"voice_sample": ["ref_audio_url", "audio_url", "voice_sample"],
"ref_audio": ["ref_audio_url", "audio_url", "ref_audio"],
}

def _fal_ai_resolve_inputs(self, inputs, model):
"""Resolve attachment IDs → data URIs and map generic keys to schema field names.

The AI assistant passes inputs like {"image": <attachment_id>}. This method:
1. Looks up the ir.attachment record for any integer value.
2. Converts it to a data URI (data:<mimetype>;base64,...).
3. Renames the key to match the field name declared in the model's input schema
(e.g. "image" → "image_url" when "image_url" is in the schema).
"""
if not inputs or not isinstance(inputs, dict):
return inputs

input_schema = self._fal_ai_get_input_schema(model)
schema_props = (input_schema or {}).get("properties", {})

resolved = {}
for key, value in inputs.items():
# Resolve the target field name via schema aliases.
target_key = key
if key in self._FAL_INPUT_FIELD_ALIASES:
for candidate in self._FAL_INPUT_FIELD_ALIASES[key]:
if candidate in schema_props:
target_key = candidate
break

# Resolve attachment references → data URI.
# Handles: integer 569, string "569", or string "attachment:569".
att_id = None
if isinstance(value, int) and value > 0:
att_id = value
elif isinstance(value, str):
raw = value.split(":", 1)[-1] if value.startswith("attachment:") else value
if raw.isdigit():
att_id = int(raw)

if att_id:
att = self.env["ir.attachment"].sudo().browse(att_id)
if att.exists() and att.datas:
mimetype = att.mimetype or "application/octet-stream"
resolved[target_key] = f"data:{mimetype};base64,{att.datas.decode()}"
_logger.info(
"fal_ai: resolved attachment %s (%s) → %s",
att_id, mimetype, target_key,
)
continue

resolved[target_key] = value

return resolved

def _fal_ai_get_input_schema(self, model):
"""Return the normalized input schema, falling back to OpenAPI refs."""
if not model:
return {}

details = model.details or {}
input_schema = details.get("input_schema")
if input_schema:
return input_schema

openapi_schema = details.get("openapi")
if not openapi_schema:
return {}

extracted_input_schema, _output_schema = self._fal_ai_extract_openapi_io_schemas(
openapi_schema
)
return extracted_input_schema or {}

def _fal_ai_validate_inputs(self, inputs, model):
"""Fail fast for missing required fal.ai fields before making the API call."""
if not inputs or not isinstance(inputs, dict):
raise UserError(_("Generation inputs are required."))

input_schema = self._fal_ai_get_input_schema(model)
required = input_schema.get("required", []) if input_schema else []
missing = [field for field in required if inputs.get(field) in (None, "")]
if missing:
raise UserError(
_("Missing required generation input(s): %s")
% ", ".join(sorted(missing))
)

def fal_ai_models(self, model_id=None):
"""Retrieve available Fal model endpoints from the platform API."""
self.ensure_one()
Expand Down Expand Up @@ -197,8 +294,7 @@ def _fal_ai_fetch_models(self, model_id=None):
_("Fal.ai model fetch failed: %s") % (e.reason or str(e))
) from e

for raw_model in payload.get("models", []):
yield raw_model
yield from payload.get("models", [])

if model_id:
break
Expand All @@ -224,15 +320,8 @@ def _fal_ai_parse_model(self, raw_model):

openapi_schema = raw_model.get("openapi")
if openapi_schema:
input_schema = (
openapi_schema.get("components", {})
.get("schemas", {})
.get("Input")
)
output_schema = (
openapi_schema.get("components", {})
.get("schemas", {})
.get("Output")
input_schema, output_schema = self._fal_ai_extract_openapi_io_schemas(
openapi_schema
)
if input_schema:
details["input_schema"] = input_schema
Expand All @@ -245,6 +334,66 @@ def _fal_ai_parse_model(self, raw_model):
"details": details,
}

def _fal_ai_extract_openapi_io_schemas(self, openapi_schema):
"""Extract fal.ai input/output schemas from OpenAPI, including named refs."""
schemas = openapi_schema.get("components", {}).get("schemas", {})
input_schema = schemas.get("Input")
output_schema = schemas.get("Output")

input_ref = self._fal_ai_find_request_body_ref(openapi_schema)
if input_ref:
input_schema = self._fal_ai_resolve_schema_ref(openapi_schema, input_ref)

output_ref = self._fal_ai_find_success_response_ref(openapi_schema)
if output_ref:
output_schema = self._fal_ai_resolve_schema_ref(openapi_schema, output_ref)

return input_schema, output_schema

def _fal_ai_find_request_body_ref(self, openapi_schema):
for path_data in openapi_schema.get("paths", {}).values():
for operation in path_data.values():
if not isinstance(operation, dict):
continue
schema = (
operation.get("requestBody", {})
.get("content", {})
.get("application/json", {})
.get("schema", {})
)
ref = schema.get("$ref")
if ref:
return ref
return None

def _fal_ai_find_success_response_ref(self, openapi_schema):
for path_data in openapi_schema.get("paths", {}).values():
for operation in path_data.values():
if not isinstance(operation, dict):
continue
responses = operation.get("responses", {})
for status in ("200", 200):
schema = (
responses.get(status, {})
.get("content", {})
.get("application/json", {})
.get("schema", {})
)
ref = schema.get("$ref")
if ref and not ref.endswith("/QueueStatus"):
return ref
return None

def _fal_ai_resolve_schema_ref(self, openapi_schema, ref):
prefix = "#/components/schemas/"
if not isinstance(ref, str) or not ref.startswith(prefix):
return None
return (
openapi_schema.get("components", {})
.get("schemas", {})
.get(ref[len(prefix):])
)

def _fal_ai_capabilities_from_category(self, category, endpoint_id):
"""Map Fal model categories to Odoo provider capabilities."""
name = endpoint_id.lower()
Expand Down
145 changes: 145 additions & 0 deletions llm_fal_ai/tests/test_fal_ai_provider.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from odoo.exceptions import UserError
from odoo.tests.common import TransactionCase


Expand Down Expand Up @@ -48,3 +49,147 @@ def test_parse_model_extracts_openapi_schemas(self):
self.assertEqual(parsed["details"]["category"], "text-to-image")
self.assertIn("input_schema", parsed["details"])
self.assertIn("output_schema", parsed["details"])

def test_parse_model_extracts_named_openapi_refs(self):
raw_model = {
"endpoint_id": "fal-ai/bytedance/seedance/v1/pro/image-to-video",
"metadata": {
"category": "image-to-video",
"description": "Image to video",
},
"openapi": {
"paths": {
"/model": {
"post": {
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SeedanceInput"
}
}
},
"required": True,
},
"responses": {
"200": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/QueueStatus"
}
}
}
}
},
}
},
"/model/requests/{request_id}": {
"get": {
"responses": {
"200": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SeedanceOutput"
}
}
}
}
}
}
},
},
"components": {
"schemas": {
"QueueStatus": {"type": "object"},
"SeedanceInput": {
"type": "object",
"required": ["prompt", "image_url"],
"properties": {
"prompt": {"type": "string"},
"image_url": {"type": "string"},
},
},
"SeedanceOutput": {
"type": "object",
"properties": {"video": {"type": "object"}},
},
}
},
},
}

parsed = self.provider._fal_ai_parse_model(raw_model)

input_schema = parsed["details"]["input_schema"]
output_schema = parsed["details"]["output_schema"]
self.assertEqual(input_schema["required"], ["prompt", "image_url"])
self.assertIn("image_url", input_schema["properties"])
self.assertIn("video", output_schema["properties"])

def test_resolve_inputs_uses_openapi_schema_fallback(self):
model = self.env["llm.model"].create({
"name": "fal-ai/bytedance/seedance/v1/pro/image-to-video",
"provider_id": self.provider.id,
"model_use": "generation",
"details": {
"openapi": {
"paths": {
"/model": {
"post": {
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SeedanceInput"
}
}
}
}
}
}
},
"components": {
"schemas": {
"SeedanceInput": {
"type": "object",
"required": ["prompt", "image_url"],
"properties": {
"prompt": {"type": "string"},
"image_url": {"type": "string"},
},
}
}
},
}
},
})

resolved = self.provider._fal_ai_resolve_inputs(
{"image": "https://example.com/image.jpg", "prompt": "make it move"},
model,
)

self.assertEqual(resolved["image_url"], "https://example.com/image.jpg")
self.assertNotIn("image", resolved)

def test_validate_inputs_fails_before_fal_api_when_prompt_missing(self):
model = self.env["llm.model"].create({
"name": "fal-ai/bytedance/seedance/v1/pro/image-to-video",
"provider_id": self.provider.id,
"model_use": "generation",
"details": {
"input_schema": {
"type": "object",
"required": ["prompt", "image_url"],
"properties": {
"prompt": {"type": "string"},
"image_url": {"type": "string"},
},
}
},
})

with self.assertRaisesRegex(UserError, "prompt"):
self.provider._fal_ai_validate_inputs({"image_url": "x"}, model)