From 04916caf7a391464c0c0dcf1fc1ae66f71977aab Mon Sep 17 00:00:00 2001
From: Jason Praful <jason.praful@gmail.com>
Date: Sun, 24 Nov 2024 02:08:19 +0000
Subject: [PATCH] add types for flux text to image model

---
 types/defines/ai.d.ts | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/types/defines/ai.d.ts b/types/defines/ai.d.ts
index 812b612b795..75b05e05eab 100644
--- a/types/defines/ai.d.ts
+++ b/types/defines/ai.d.ts
@@ -156,10 +156,16 @@ export type AiTextToImageInput = {
   strength?: number;
   guidance?: number;
 };
-export type AiTextToImageOutput = ReadableStream<Uint8Array>;
-export declare abstract class BaseAiTextToImage {
+export type AiTextToImageOutput<Model extends BaseAiTextToImageModels> =
+  Model extends "@cf/black-forest-labs/flux-1-schnell"
+    ? { image: string }
+    : ReadableStream<Uint8Array>;
+
+export declare abstract class BaseAiTextToImage<
+  Model extends BaseAiTextToImageModels,
+> {
   inputs: AiTextToImageInput;
-  postProcessedOutputs: AiTextToImageOutput;
+  postProcessedOutputs: AiTextToImageOutput<Model>;
 }
 export type AiTranslationInput = {
   text: string;
@@ -193,7 +199,8 @@ export type BaseAiTextToImageModels =
   | "@cf/runwayml/stable-diffusion-v1-5-inpainting"
   | "@cf/runwayml/stable-diffusion-v1-5-img2img"
   | "@cf/lykon/dreamshaper-8-lcm"
-  | "@cf/bytedance/stable-diffusion-xl-lightning";
+  | "@cf/bytedance/stable-diffusion-xl-lightning"
+  | "@cf/black-forest-labs/flux-1-schnell";
 export type BaseAiTextEmbeddingsModels =
   | "@cf/baai/bge-small-en-v1.5"
   | "@cf/baai/bge-base-en-v1.5"
@@ -252,11 +259,11 @@ export declare abstract class Ai {
     inputs: BaseAiTextClassification["inputs"],
     options?: AiOptions
   ): Promise<BaseAiTextClassification["postProcessedOutputs"]>;
-  run(
-    model: BaseAiTextToImageModels,
-    inputs: BaseAiTextToImage["inputs"],
+  run<Model extends BaseAiTextToImageModels>(
+    model: Model,
+    inputs: BaseAiTextToImage<Model>["inputs"],
     options?: AiOptions
-  ): Promise<BaseAiTextToImage["postProcessedOutputs"]>;
+  ): Promise<BaseAiTextToImage<Model>["postProcessedOutputs"]>;
   run(
     model: BaseAiTextEmbeddingsModels,
     inputs: BaseAiTextEmbeddings["inputs"],