Skip to content

Commit 218d001

Browse files
authored
feat (provider): Add maxImagesPerCall setting to all image models/providers. (#4364)
1 parent 971be76 commit 218d001

17 files changed

Lines changed: 318 additions & 76 deletions

.changeset/nervous-gifts-bake.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
---
2+
'@ai-sdk/google-vertex': patch
3+
'@ai-sdk/fireworks': patch
4+
'@ai-sdk/replicate': patch
5+
'@ai-sdk/openai': patch
6+
---
7+
8+
feat (provider): Add maxImagesPerCall setting to all image providers.

content/docs/03-ai-sdk-core/35-image-generation.mdx

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,22 @@ const { images } = await generateImage({
8383
parallel) to generate the requested number of images.
8484
</Note>
8585

86+
Each image model has an internal limit on how many images it can generate in a single API call. The AI SDK manages this automatically by batching requests appropriately when you request multiple images using the `n` parameter. By default, the SDK uses provider-documented limits (for example, DALL-E 3 can only generate 1 image per call, while DALL-E 2 supports up to 10).
87+
88+
If needed, you can override this behavior using the `maxImagesPerCall` setting when configuring your model. This is particularly useful when working with new or custom models where the default batch size might not be optimal:
89+
90+
```tsx
91+
const model = openai.image('dall-e-2', {
92+
maxImagesPerCall: 5, // Override the default batch size
93+
});
94+
95+
const { images } = await generateImage({
96+
model,
97+
prompt: 'Santa Claus driving a Cadillac',
98+
n: 10, // Will make 2 calls of 5 images each
99+
});
100+
```
101+
86102
### Providing a Seed
87103

88104
You can provide a seed to the `generateImage` function to control the output of the image generation process.

packages/fireworks/src/fireworks-image-model.test.ts

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,35 @@ import { FetchFunction } from '@ai-sdk/provider-utils';
22
import { createTestServer } from '@ai-sdk/provider-utils/test';
33
import { describe, expect, it } from 'vitest';
44
import { FireworksImageModel } from './fireworks-image-model';
5+
import { FireworksImageSettings } from './fireworks-image-settings';
56

67
const prompt = 'A cute baby sea otter';
78

89
function createBasicModel({
910
headers,
1011
fetch,
12+
settings,
1113
}: {
1214
headers?: () => Record<string, string>;
1315
fetch?: FetchFunction;
16+
settings?: FireworksImageSettings;
1417
} = {}) {
15-
return new FireworksImageModel('accounts/fireworks/models/flux-1-dev-fp8', {
16-
provider: 'fireworks',
17-
baseURL: 'https://api.example.com',
18-
headers: headers ?? (() => ({ 'api-key': 'test-key' })),
19-
fetch,
20-
});
18+
return new FireworksImageModel(
19+
'accounts/fireworks/models/flux-1-dev-fp8',
20+
settings ?? {},
21+
{
22+
provider: 'fireworks',
23+
baseURL: 'https://api.example.com',
24+
headers: headers ?? (() => ({ 'api-key': 'test-key' })),
25+
fetch,
26+
},
27+
);
2128
}
2229

2330
function createSizeModel() {
2431
return new FireworksImageModel(
2532
'accounts/fireworks/models/playground-v2-5-1024px-aesthetic',
33+
{},
2634
{
2735
provider: 'fireworks',
2836
baseURL: 'https://api.size-example.com',
@@ -298,5 +306,21 @@ describe('FireworksImageModel', () => {
298306
expect(model.specificationVersion).toBe('v1');
299307
expect(model.maxImagesPerCall).toBe(1);
300308
});
309+
310+
it('should use maxImagesPerCall from settings', () => {
311+
const model = createBasicModel({
312+
settings: {
313+
maxImagesPerCall: 4,
314+
},
315+
});
316+
317+
expect(model.maxImagesPerCall).toBe(4);
318+
});
319+
320+
it('should default maxImagesPerCall to 1 when not specified', () => {
321+
const model = createBasicModel();
322+
323+
expect(model.maxImagesPerCall).toBe(1);
324+
});
301325
});
302326
});

packages/fireworks/src/fireworks-image-model.ts

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,10 @@ import {
1010
postJsonToApi,
1111
ResponseHandler,
1212
} from '@ai-sdk/provider-utils';
13-
14-
// https://fireworks.ai/models?type=image
15-
export type FireworksImageModelId =
16-
| 'accounts/fireworks/models/flux-1-dev-fp8'
17-
| 'accounts/fireworks/models/flux-1-schnell-fp8'
18-
| 'accounts/fireworks/models/playground-v2-5-1024px-aesthetic'
19-
| 'accounts/fireworks/models/japanese-stable-diffusion-xl'
20-
| 'accounts/fireworks/models/playground-v2-1024px-aesthetic'
21-
| 'accounts/fireworks/models/SSD-1B'
22-
| 'accounts/fireworks/models/stable-diffusion-xl-1024-v1-0'
23-
| (string & {});
13+
import {
14+
FireworksImageModelId,
15+
FireworksImageSettings,
16+
} from './fireworks-image-settings';
2417

2518
interface FireworksImageModelBackendConfig {
2619
urlFormat: 'workflows' | 'image_generation';
@@ -141,10 +134,13 @@ export class FireworksImageModel implements ImageModelV1 {
141134
return this.config.provider;
142135
}
143136

144-
readonly maxImagesPerCall = 1;
137+
get maxImagesPerCall(): number {
138+
return this.settings.maxImagesPerCall ?? 1;
139+
}
145140

146141
constructor(
147142
readonly modelId: FireworksImageModelId,
143+
readonly settings: FireworksImageSettings,
148144
private config: FireworksImageModelConfig,
149145
) {}
150146

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// https://fireworks.ai/models?type=image
2+
export type FireworksImageModelId =
3+
| 'accounts/fireworks/models/flux-1-dev-fp8'
4+
| 'accounts/fireworks/models/flux-1-schnell-fp8'
5+
| 'accounts/fireworks/models/playground-v2-5-1024px-aesthetic'
6+
| 'accounts/fireworks/models/japanese-stable-diffusion-xl'
7+
| 'accounts/fireworks/models/playground-v2-1024px-aesthetic'
8+
| 'accounts/fireworks/models/SSD-1B'
9+
| 'accounts/fireworks/models/stable-diffusion-xl-1024-v1-0'
10+
| (string & {});
11+
12+
export interface FireworksImageSettings {
13+
/**
14+
Override the maximum number of images per call (default 1)
15+
*/
16+
maxImagesPerCall?: number;
17+
}

packages/fireworks/src/fireworks-provider.test.ts

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import { describe, it, expect, vi, beforeEach, Mock } from 'vitest';
22
import { createFireworks } from './fireworks-provider';
3-
import { createOpenAICompatible } from '@ai-sdk/openai-compatible';
43
import { LanguageModelV1, EmbeddingModelV1 } from '@ai-sdk/provider';
54
import { loadApiKey } from '@ai-sdk/provider-utils';
65
import {
76
OpenAICompatibleChatLanguageModel,
87
OpenAICompatibleCompletionLanguageModel,
98
OpenAICompatibleEmbeddingModel,
109
} from '@ai-sdk/openai-compatible';
10+
import { FireworksImageModel } from './fireworks-image-model';
1111

1212
// Add type assertion for the mocked class
1313
const OpenAICompatibleChatLanguageModelMock =
@@ -24,6 +24,10 @@ vi.mock('@ai-sdk/provider-utils', () => ({
2424
withoutTrailingSlash: vi.fn(url => url),
2525
}));
2626

27+
vi.mock('./fireworks-image-model', () => ({
28+
FireworksImageModel: vi.fn(),
29+
}));
30+
2731
describe('FireworksProvider', () => {
2832
let mockLanguageModel: LanguageModelV1;
2933
let mockEmbeddingModel: EmbeddingModelV1<string>;
@@ -125,4 +129,54 @@ describe('FireworksProvider', () => {
125129
expect(model).toBeInstanceOf(OpenAICompatibleEmbeddingModel);
126130
});
127131
});
132+
133+
describe('image', () => {
134+
it('should construct an image model with correct configuration', () => {
135+
const provider = createFireworks();
136+
const modelId = 'accounts/fireworks/models/flux-1-dev-fp8';
137+
const settings = { maxImagesPerCall: 2 };
138+
139+
const model = provider.image(modelId, settings);
140+
141+
expect(model).toBeInstanceOf(FireworksImageModel);
142+
expect(FireworksImageModel).toHaveBeenCalledWith(
143+
modelId,
144+
settings,
145+
expect.objectContaining({
146+
provider: 'fireworks.image',
147+
baseURL: 'https://api.fireworks.ai/inference/v1',
148+
}),
149+
);
150+
});
151+
152+
it('should use default settings when none provided', () => {
153+
const provider = createFireworks();
154+
const modelId = 'accounts/fireworks/models/flux-1-dev-fp8';
155+
156+
const model = provider.image(modelId);
157+
158+
expect(model).toBeInstanceOf(FireworksImageModel);
159+
expect(FireworksImageModel).toHaveBeenCalledWith(
160+
modelId,
161+
{},
162+
expect.any(Object),
163+
);
164+
});
165+
166+
it('should respect custom baseURL', () => {
167+
const customBaseURL = 'https://custom.api.fireworks.ai';
168+
const provider = createFireworks({ baseURL: customBaseURL });
169+
const modelId = 'accounts/fireworks/models/flux-1-dev-fp8';
170+
171+
const model = provider.image(modelId);
172+
173+
expect(FireworksImageModel).toHaveBeenCalledWith(
174+
modelId,
175+
expect.any(Object),
176+
expect.objectContaining({
177+
baseURL: customBaseURL,
178+
}),
179+
);
180+
});
181+
});
128182
});

packages/fireworks/src/fireworks-provider.ts

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,11 @@ import {
2727
FireworksEmbeddingModelId,
2828
FireworksEmbeddingSettings,
2929
} from './fireworks-embedding-settings';
30+
import { FireworksImageModel } from './fireworks-image-model';
3031
import {
31-
FireworksImageModel,
3232
FireworksImageModelId,
33-
} from './fireworks-image-model';
33+
FireworksImageSettings,
34+
} from './fireworks-image-settings';
3435

3536
export type FireworksErrorData = z.infer<typeof fireworksErrorSchema>;
3637

@@ -100,7 +101,10 @@ Creates a text embedding model for text generation.
100101
/**
101102
Creates a model for image generation.
102103
*/
103-
image(modelId: FireworksImageModelId): ImageModelV1;
104+
image(
105+
modelId: FireworksImageModelId,
106+
settings?: FireworksImageSettings,
107+
): ImageModelV1;
104108
}
105109

106110
const defaultBaseURL = 'https://api.fireworks.ai/inference/v1';
@@ -161,8 +165,11 @@ export function createFireworks(
161165
errorStructure: fireworksErrorStructure,
162166
});
163167

164-
const createImageModel = (modelId: FireworksImageModelId) =>
165-
new FireworksImageModel(modelId, {
168+
const createImageModel = (
169+
modelId: FireworksImageModelId,
170+
settings: FireworksImageSettings = {},
171+
) =>
172+
new FireworksImageModel(modelId, settings, {
166173
...getCommonModelConfig('image'),
167174
baseURL: baseURL ?? defaultBaseURL,
168175
});

packages/google-vertex/src/google-vertex-image-model.test.ts

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@ import { GoogleVertexImageModel } from './google-vertex-image-model';
44

55
const prompt = 'A cute baby sea otter';
66

7-
const model = new GoogleVertexImageModel('imagen-3.0-generate-001', {
8-
provider: 'google-vertex',
9-
baseURL: 'https://api.example.com',
10-
headers: { 'api-key': 'test-key' },
11-
});
7+
const model = new GoogleVertexImageModel(
8+
'imagen-3.0-generate-001',
9+
{},
10+
{
11+
provider: 'google-vertex',
12+
baseURL: 'https://api.example.com',
13+
headers: { 'api-key': 'test-key' },
14+
},
15+
);
1216

1317
describe('GoogleVertexImageModel', () => {
1418
describe('doGenerate', () => {
@@ -53,6 +57,7 @@ describe('GoogleVertexImageModel', () => {
5357

5458
const modelWithHeaders = new GoogleVertexImageModel(
5559
'imagen-3.0-generate-001',
60+
{},
5661
{
5762
provider: 'google-vertex',
5863
baseURL: 'https://api.example.com',
@@ -83,6 +88,34 @@ describe('GoogleVertexImageModel', () => {
8388
});
8489
});
8590

91+
it('should respect maxImagesPerCall setting', () => {
92+
const customModel = new GoogleVertexImageModel(
93+
'imagen-3.0-generate-001',
94+
{ maxImagesPerCall: 2 },
95+
{
96+
provider: 'google-vertex',
97+
baseURL: 'https://api.example.com',
98+
headers: { 'api-key': 'test-key' },
99+
},
100+
);
101+
102+
expect(customModel.maxImagesPerCall).toBe(2);
103+
});
104+
105+
it('should use default maxImagesPerCall when not specified', () => {
106+
const defaultModel = new GoogleVertexImageModel(
107+
'imagen-3.0-generate-001',
108+
{},
109+
{
110+
provider: 'google-vertex',
111+
baseURL: 'https://api.example.com',
112+
headers: { 'api-key': 'test-key' },
113+
},
114+
);
115+
116+
expect(defaultModel.maxImagesPerCall).toBe(4);
117+
});
118+
86119
it('should extract the generated images', async () => {
87120
prepareJsonResponse();
88121

packages/google-vertex/src/google-vertex-image-model.ts

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@ import {
88
} from '@ai-sdk/provider-utils';
99
import { z } from 'zod';
1010
import { googleVertexFailedResponseHandler } from './google-vertex-error';
11-
12-
export type GoogleVertexImageModelId =
13-
| 'imagen-3.0-generate-001'
14-
| 'imagen-3.0-fast-generate-001'
15-
| (string & {});
11+
import {
12+
GoogleVertexImageModelId,
13+
GoogleVertexImageSettings,
14+
} from './google-vertex-image-settings';
1615

1716
interface GoogleVertexImageModelConfig {
1817
provider: string;
@@ -25,15 +24,18 @@ interface GoogleVertexImageModelConfig {
2524
export class GoogleVertexImageModel implements ImageModelV1 {
2625
readonly specificationVersion = 'v1';
2726

28-
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api#parameter_list
29-
readonly maxImagesPerCall = 4;
30-
3127
get provider(): string {
3228
return this.config.provider;
3329
}
3430

31+
get maxImagesPerCall(): number {
32+
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api#parameter_list
33+
return this.settings.maxImagesPerCall ?? 4;
34+
}
35+
3536
constructor(
3637
readonly modelId: GoogleVertexImageModelId,
38+
readonly settings: GoogleVertexImageSettings,
3739
private config: GoogleVertexImageModelConfig,
3840
) {}
3941

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
export type GoogleVertexImageModelId =
2+
| 'imagen-3.0-generate-001'
3+
| 'imagen-3.0-fast-generate-001'
4+
| (string & {});
5+
6+
export interface GoogleVertexImageSettings {
7+
/**
8+
Override the maximum number of images per call (default 4)
9+
*/
10+
maxImagesPerCall?: number;
11+
}

0 commit comments

Comments
 (0)