diff --git a/ai_models/admin.py b/ai_models/admin.py index eca275f5a..fd297fd89 100644 --- a/ai_models/admin.py +++ b/ai_models/admin.py @@ -1,17 +1,18 @@ from django.contrib import admin from django.db import models -from ai_models.models import VideoModelSpec +from ai_models.models import AIModelSpec from gooeysite.custom_widgets import JSONEditorWidget from usage_costs.admin import ModelPricingAdmin -@admin.register(VideoModelSpec) -class VideoModelSpecAdmin(admin.ModelAdmin): +@admin.register(AIModelSpec) +class AIModelSpecAdmin(admin.ModelAdmin): list_display = [ "name", "label", "model_id", + "category", "created_at", "updated_at", ] @@ -19,15 +20,14 @@ class VideoModelSpecAdmin(admin.ModelAdmin): list_filter = [ "pricing__category", "pricing__provider", + "category", "created_at", "updated_at", ] - search_fields = [ - "name", - "label", - "model_id", - ] + [f"pricing__{field}" for field in ModelPricingAdmin.search_fields] + search_fields = ["name", "label", "model_id", "category"] + [ + f"pricing__{field}" for field in ModelPricingAdmin.search_fields + ] autocomplete_fields = ["pricing"] readonly_fields = [ @@ -40,6 +40,7 @@ class VideoModelSpecAdmin(admin.ModelAdmin): "Model Information", { "fields": [ + "category", "name", "label", "model_id", diff --git a/ai_models/migrations/0003_aimodelspec_delete_videomodelspec.py b/ai_models/migrations/0003_aimodelspec_delete_videomodelspec.py new file mode 100644 index 000000000..8a7d4fb0f --- /dev/null +++ b/ai_models/migrations/0003_aimodelspec_delete_videomodelspec.py @@ -0,0 +1,42 @@ +# Generated by Django 5.1.3 on 2025-12-01 10:57 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("ai_models", "0002_alter_videomodelspec_pricing"), + ("usage_costs", "0037_alter_modelpricing_model_name"), + ] + + operations = [ + migrations.RenameModel("VideoModelSpec", "AIModelSpec"), + migrations.AlterModelOptions( + name="aimodelspec", 
+ options={"verbose_name": "AI Model Spec"}, + ), + migrations.AddField( + model_name="aimodelspec", + name="category", + field=models.IntegerField( + choices=[(1, "🎥 Video"), (2, "🎵 Audio")], + default=1, + help_text="Model category: generates Audio, Video, etc.", + null=True, + ), + ), + migrations.AlterField( + model_name="aimodelspec", + name="pricing", + field=models.ForeignKey( + blank=True, + default=None, + help_text="The pricing of the model. Only for display purposes. The actual pricing lookup uses the model_id, so make sure the video model's model_id matches the pricing's model_id.To setup a price multiplier, create a WorkflowMetadata for the workflow and set the price_multiplier field.", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="ai_model_specs", + to="usage_costs.modelpricing", + ), + ), + ] diff --git a/ai_models/models.py b/ai_models/models.py index 2c2dce979..10520c2aa 100644 --- a/ai_models/models.py +++ b/ai_models/models.py @@ -1,7 +1,18 @@ from django.db import models -class VideoModelSpec(models.Model): +class AIModelSpec(models.Model): + class Categories(models.IntegerChoices): + video = (1, "🎥 Video") + audio = (2, "🎵 Audio") + + category = models.IntegerField( + choices=Categories.choices, + help_text="Model category: generates Audio, Video, etc.", + null=True, + default=Categories.video, + ) + name = models.TextField( unique=True, help_text="The name of the model to be used in user-facing API calls. 
WARNING: Don't edit this field after it's been used in a workflow.", @@ -20,7 +31,7 @@ class VideoModelSpec(models.Model): pricing = models.ForeignKey( "usage_costs.ModelPricing", on_delete=models.SET_NULL, - related_name="video_model_specs", + related_name="ai_model_specs", null=True, blank=True, default=None, @@ -32,5 +43,8 @@ class VideoModelSpec(models.Model): created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) + class Meta: + verbose_name = "AI Model Spec" + def __str__(self): return f"{self.label} ({self.model_id})" diff --git a/recipes/VideoGenPage.py b/recipes/VideoGenPage.py index 042823fae..fae5324ac 100644 --- a/recipes/VideoGenPage.py +++ b/recipes/VideoGenPage.py @@ -1,27 +1,33 @@ from __future__ import annotations import json +import os +import tempfile import typing from concurrent.futures import ThreadPoolExecutor from queue import Queue from textwrap import dedent import gooey_gui as gui +import requests from django.db.models import Q from pydantic import BaseModel from requests.utils import CaseInsensitiveDict -from ai_models.models import VideoModelSpec +from ai_models.models import AIModelSpec from bots.models import Workflow +from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2.base import BasePage -from daras_ai_v2.exceptions import UserError +from daras_ai_v2.exceptions import UserError, ffmpeg, ffprobe from daras_ai_v2.fal_ai import generate_on_fal from daras_ai_v2.functional import get_initializer +from daras_ai_v2.language_model_openai_realtime import yield_from from daras_ai_v2.pydantic_validation import HttpUrlStr from daras_ai_v2.safety_checker import SAFETY_CHECKER_MSG, safety_checker from daras_ai_v2.variables_widget import render_prompt_vars from usage_costs.cost_utils import record_cost_auto from usage_costs.models import ModelSku +from widgets.switch_with_section import switch_with_section class VideoGenPage(BasePage): @@ -35,6 +41,10 @@ class 
RequestModel(BasePage.RequestModel):
         selected_models: list[str]
         inputs: dict[str, typing.Any]
 
+        # Audio generation settings
+        selected_audio_model: str | None = None
+        audio_inputs: dict[str, typing.Any] | None = None
+
     class ResponseModel(BaseModel):
         output_videos: dict[str, HttpUrlStr]
 
@@ -44,23 +54,17 @@ def run_v2(
         if not request.selected_models:
             raise UserError("Please select at least one model")
 
-        for key in ["prompt", "negative_prompt"]:
-            if not request.inputs.get(key):
-                continue
-            # Render any template variables in the prompt
-            request.inputs[key] = render_prompt_vars(
-                request.inputs[key], gui.session_state
-            )
-            # Safety check if not disabled
-            if self.request.user.disable_safety_checker:
-                continue
-            yield "Running safety checker..."
-            safety_checker(text=request.inputs[key])
+        yield from self.run_safety_checker(request)
 
         q = Q()
         for model_name in request.selected_models:
             q |= Q(name__icontains=model_name)
-        models = VideoModelSpec.objects.filter(q)
+        models = AIModelSpec.objects.filter(q)
+
+        if request.selected_audio_model:
+            audio_model = AIModelSpec.objects.get(name=request.selected_audio_model)
+        else:
+            audio_model = None
 
         progress_q = Queue()
         progress = {model.model_id: "" for model in models}
@@ -68,32 +72,61 @@ def run_v2(
         with ThreadPoolExecutor(
             max_workers=len(models), initializer=get_initializer()
         ) as pool:
-            fs = {
-                model: pool.submit(
+            fs = [
+                pool.submit(
                     generate_video,
-                    model,
-                    request.inputs | dict(enable_safety_checker=False),
-                    progress_q,
+                    model=model,
+                    inputs=request.inputs | dict(enable_safety_checker=False),
+                    audio_model=audio_model,
+                    audio_inputs=request.audio_inputs,
+                    progress_q=progress_q,
+                    output_videos=response.output_videos,
                 )
                 for model in models
-            }
+            ]
             # print(f"{fs=}")
             yield f"Running {', '.join(model.label for model in models)}"
-            while not all(fut.done() for fut in fs.values()):
+            while not all(fut.done() for fut in fs):
                 model_id, msg = progress_q.get()
                 if not msg:
                     continue
                 progress[model_id] = msg
                 yield
"\n".join(progress.values())
+        for fut in fs:
+            fut.result()
 
-        for model, fut in fs.items():
-            response.output_videos[model.name] = fut.result()
-            # print(f"{response.output_videos=}")
+    def run_safety_checker(self, request: VideoGenPage.RequestModel):
+        if self.request.user.disable_safety_checker:
+            return
+        for inputs in [request.inputs, request.audio_inputs]:
+            if not inputs:
+                continue
+            for key in ["prompt", "negative_prompt"]:
+                text = inputs.get(key)
+                if not text:
+                    continue
+                # Render any template variables in the prompt
+                inputs[key] = render_prompt_vars(inputs[key], gui.session_state)
+                yield "Running safety checker..."
+                safety_checker(text=inputs[key])
 
     def render(self):
         self.available_models = CaseInsensitiveDict(
-            {model.name: model for model in VideoModelSpec.objects.all()}
+            {
+                model.name: model
+                for model in AIModelSpec.objects.filter(
+                    category=AIModelSpec.Categories.video
+                )
+            }
+        )
+        self.available_audio_models = CaseInsensitiveDict(
+            {
+                model.name: model
+                for model in AIModelSpec.objects.filter(
+                    category=AIModelSpec.Categories.audio
+                )
+            }
         )
         super().render()
 
@@ -117,6 +150,16 @@ def additional_notes(self) -> str | None:
     def render_form_v2(self):
         render_video_gen_form(self.available_models)
 
+        generate_audio = switch_with_section(
+            label="##### Generate sound",
+            control_keys=["selected_audio_model"],
+            render_section=self.generate_audio_settings,
+            disabled=not gui.session_state.get("selected_models"),
+        )
+        if not generate_audio:
+            gui.session_state["selected_audio_model"] = None
+            gui.session_state["audio_inputs"] = None
+
     def render_output(self):
         self.render_run_preview_output(gui.session_state, show_download_button=True)
 
@@ -178,11 +221,19 @@ def render_usage_guide(self):
             """
         )
 
+    def generate_audio_settings(self):
+        render_audio_gen_form(self.available_audio_models)
+
 
 def generate_video(
-    model: VideoModelSpec, inputs: dict, progress_q: Queue[tuple[str, str | None]]
-) -> str:
-    # print(f"{model=} {inputs=} {event=} {progress=}
{output_videos=}") + model: AIModelSpec, + inputs: dict, + audio_model: AIModelSpec | None, + audio_inputs: dict[str, typing.Any] | None, + progress_q: Queue[tuple[str, str | None]], + output_videos: dict[str, str], +): + # print(f"{model=} {inputs=} {audio_model=} {audio_inputs=}") gen = generate_on_fal(model.model_id, inputs) try: while True: @@ -198,13 +249,56 @@ def generate_video( elif not video_out: raise UserError(f"No video output: {out}") record_cost_auto(model.model_id, ModelSku.video_generation, 1) - return video_out["url"] - # print(f"{output_videos[model.name]=}") + video_url = get_url_from_result(video_out) + output_videos[model.name] = video_url + # print(f"{video_url=}") + if audio_model and audio_inputs: + progress_q.put((model.model_id, f"Generating audio with {audio_model}...")) + output_videos[model.name] = generate_audio( + video_url, inputs, audio_model, audio_inputs + ) finally: progress_q.put((model.model_id, None)) -def render_video_gen_form(available_models: dict[str, VideoModelSpec]): +def generate_audio( + video_url: str, + inputs: dict, + audio_model: AIModelSpec, + audio_inputs: dict[str, typing.Any], +) -> str: + duration = float(ffprobe(video_url)["streams"][0]["duration"]) + duration_props = resolve_field_anyof( + extract_openapi_schema(audio_model.schema, "request") + .get("properties", {}) + .get("duration", {}) + ) + minimum = duration_props.get("minimum") + maximum = duration_props.get("maximum") + if minimum: + duration = max(minimum, duration) + if maximum: + duration = min(maximum, duration) + + payload = {"video_url": video_url, "duration": duration} | audio_inputs + if not payload.get("prompt"): + payload["prompt"] = inputs.get("prompt") + res = yield_from(generate_on_fal(audio_model.model_id, payload)) + record_cost_auto(audio_model.model_id, ModelSku.video_generation, 1) + res_video = get_url_from_result(res.get("video")) + res_audio = get_url_from_result(res.get("audio")) + + if res_video: + return res_video + elif 
res_audio: + audio_url = get_url_from_result(res_audio) + filename = f"{audio_model.label}_merged.mp4" + return merge_audio_and_video(filename, audio_url, video_url) + else: + raise ValueError(f"No video/audio output from {audio_model.name}") + + +def render_video_gen_form(available_models: dict[str, AIModelSpec]): # normalize the selected model names gui.session_state["selected_models"] = [ model.name @@ -218,27 +312,80 @@ def render_video_gen_form(available_models: dict[str, VideoModelSpec]): key="selected_models", allow_none=True, ) + render_fields( + key="inputs", available_models=available_models, selected_models=selected_models + ) + + +def render_audio_gen_form(available_audio_models: dict[str, AIModelSpec]): + with gui.div(className="pt-1 pb-1"): + gui.caption( + "Automatically add sound effects. Note: Overwrites audio if present in video. " + ) + + current_audio_model = gui.session_state.get("selected_audio_model") + if current_audio_model and current_audio_model not in available_audio_models: + gui.session_state["selected_audio_model"] = None + + if not available_audio_models: + return + + selected_audio_model = gui.selectbox( + label="###### Sound generation model", + options=list(available_audio_models.keys()), + format_func=lambda x: available_audio_models[x].label, + key="selected_audio_model", + allow_none=False, + disabled=not available_audio_models + or not gui.session_state.get("selected_models"), + ) + + render_fields( + key="audio_inputs", + available_models=available_audio_models, + selected_models=[selected_audio_model], + skip_fields=["video_url", "duration"], + ) + + +def render_fields( + key: str, + available_models: dict[str, AIModelSpec], + selected_models: list[str], + skip_fields: typing.Iterable[str] = (), +): models = list( filter(None, (available_models.get(name) for name in selected_models)) ) if not models: - return + return {} + + try: + model_input_schemas = [ + schema + for model in models + if (schema := 
extract_openapi_schema(model.schema, "request")) + ] + except Exception as e: + gui.error(f"Error getting input fields: {e}") + return {} common_fields = set.intersection( - *(set(model.schema["properties"]) for model in models) + *(set(schema.get("properties", {})) for schema in model_input_schemas) ) - schema = models[0].schema + + schema = model_input_schemas[0] required_fields = set(schema.get("required", [])) ordered_fields = schema.get("x-fal-order-properties") or list(common_fields) ordered_fields.sort(key=lambda x: x not in required_fields) - old_inputs = gui.session_state.get("inputs", {}) + old_inputs = gui.session_state.get(key) or {} new_inputs = {} for name in ordered_fields: - if name not in common_fields: + if name not in common_fields or name in skip_fields: continue - field = models[0].schema["properties"][name] + field = model_input_schemas[0]["properties"][name] label = field.get("title") or name.title() if name in required_fields: label = "##### " + label @@ -248,23 +395,17 @@ def render_video_gen_form(available_models: dict[str, VideoModelSpec]): field=field, name=name, label=label, value=value ) - gui.session_state["inputs"] = new_inputs + gui.session_state[key] = new_inputs -def render_field( - *, - field: dict, - name: str, - label: str, - value: typing.Any, -): +def render_field(*, field: dict, name: str, label: str, value: typing.Any): description = field.get("description") if description: help_text = dedent(description) else: help_text = None - - match get_field_type(field): + field = resolve_field_anyof(field) + match field["type"]: case "array" if "lora" in name or "url" in name: return gui.file_uploader( label=label, @@ -280,28 +421,33 @@ def render_field( ) case ("string" | "integer" | "number") as _type if field.get("enum"): v = gui.selectbox( - label=label, - value=value, - help=help_text, - options=field["enum"], + label=label, value=value, help=help_text, options=field["enum"] ) pytype = {"string": str, "integer": int, "number": 
float}[_type] return pytype(v) case "string": - return gui.text_area( - label=label, - value=value, - help=help_text, - ) + return gui.text_area(label=label, value=value, help=help_text) case "integer": - return gui.number_input( - label=label, - value=value, - help=help_text, - min_value=field.get("minimum"), - max_value=field.get("maximum"), - step=1, - ) + minimum = field.get("minimum") + maximum = field.get("maximum") + if minimum and maximum: + return gui.slider( + label=label, + min_value=minimum, + max_value=maximum, + value=value, + step=1, + help=help_text, + ) + else: + return gui.number_input( + label=label, + value=value, + help=help_text, + min_value=minimum, + max_value=maximum, + step=1, + ) case "number": return gui.number_input( label=label, @@ -312,11 +458,7 @@ def render_field( step=0.1, ) case "boolean": - return gui.checkbox( - label=label, - value=value, - help=help_text, - ) + return gui.checkbox(label=label, value=value, help=help_text) case "object": try: json_str = json.dumps(value, indent=2) @@ -336,12 +478,98 @@ def render_field( gui.error("Value must be a JSON object") -def get_field_type(field: dict) -> str: - try: - return field["type"] - except KeyError: - for props in field.get("anyOf", []): - inner_type = props.get("type") - if inner_type and inner_type != "null": - return inner_type - return "object" +def resolve_field_anyof(field: dict) -> dict: + if field.get("type"): + return field + for props in field.get("anyOf", []): + inner_type = props.get("type") + if inner_type and inner_type != "null": + return props + return {"type": "object"} + + +def extract_openapi_schema( + openapi_json: dict, schema_type: typing.Literal["request", "response"] +) -> dict | None: + if openapi_json.get("properties"): + return openapi_json + + endpoint_id = ( + openapi_json.get("info", {}).get("x-fal-metadata", {}).get("endpointId") + ) + + paths = openapi_json.get("paths", {}) + + if schema_type == "request": + path_key = f"/{endpoint_id}" + 
method_data = paths.get(path_key, {}).get("post", {}) + schema_ref = ( + method_data.get("requestBody", {}) + .get("content", {}) + .get("application/json", {}) + .get("schema", {}) + .get("$ref") + ) + else: # output + path_key = f"/{endpoint_id}/requests/{{request_id}}" + method_data = paths.get(path_key, {}).get("get", {}) + schema_ref = ( + method_data.get("responses", {}) + .get("200", {}) + .get("content", {}) + .get("application/json", {}) + .get("schema", {}) + .get("$ref") + ) + + if not schema_ref: + return {} + + schema_name = schema_ref.split("/")[-1] + return openapi_json.get("components", {}).get("schemas", {}).get(schema_name, {}) + + +def get_url_from_result(result: dict | list | str | None) -> str | None: + if not result: + return None + match result: + case list(): + return result[0] + case dict(): + return result.get("url") + case _: + return result + + +def merge_audio_and_video(filename: str, audio_url: str, video_url: str) -> str: + with tempfile.TemporaryDirectory() as tmpdir: + video_path = os.path.join(tmpdir, "video.mp4") + audio_path = os.path.join(tmpdir, "audio.wav") + output_path = os.path.join(tmpdir, "merged_video.mp4") + + video_response = requests.get(video_url) + video_response.raise_for_status() + with open(video_path, "wb") as f: + f.write(video_response.content) + + audio_response = requests.get(audio_url) + audio_response.raise_for_status() + with open(audio_path, "wb") as f: + f.write(audio_response.content) + + ffmpeg( + "-i", video_path, + "-stream_loop", "-1", + "-i", audio_path, + "-c:v", "copy", + "-c:a", "aac", + "-map", "0:v:0", + "-map", "1:a:0", + "-shortest", + output_path, + ) # fmt:skip + + with open(output_path, "rb") as f: + merged_video_bytes = f.read() + + return upload_file_from_bytes(filename, merged_video_bytes, "video/mp4")