Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion plugins/groq/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"groq-sdk": "^0.19.0"
},
"peerDependencies": {
"genkit": "^0.9.0 || ^1.0.0"
"genkit": "^1.19.3"
},
"devDependencies": {
"@types/hast": "^3.0.4",
Expand Down
26 changes: 13 additions & 13 deletions plugins/groq/src/groq_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import { ChatCompletion } from 'groq-sdk/resources/chat/index.mjs';
import {
GenerateRequest,
GenerationCommonConfigSchema,
Genkit,
Message,
MessageData,
Part,
Expand All @@ -40,6 +39,7 @@ import {
modelRef,
ToolDefinition,
} from 'genkit/model';
import { model } from 'genkit/plugin';

export const GroqConfigSchema = GenerationCommonConfigSchema.extend({
stream: z.boolean().optional(),
Expand Down Expand Up @@ -572,29 +572,29 @@ export function toGroqRequestBody(
}

/**
* Defines a Groq model.
* Creates a Groq model action.
*
* @param name - The name of the model.
* @param client - The Groq client.
* @returns The model.
* @returns The model action.
*/
export function groqModel(ai: Genkit, name: string, client: Groq) {
const model = SUPPORTED_GROQ_MODELS[name];
if (!model) throw new Error(`Unsupported model: ${name}`);
export function createGroqModel(name: string, client: Groq) {
const modelRef = SUPPORTED_GROQ_MODELS[name];
if (!modelRef) throw new Error(`Unsupported model: ${name}`);
const modelId = `groq/${name}`;

return ai.defineModel(
return model(
{
name: modelId,
...model.info,
configSchema: model.configSchema,
...modelRef.info,
configSchema: modelRef.configSchema,
},
async (
request,
streamingCallback?: StreamingCallback<GenerateResponseChunkData>
) => {
async (request, options: any) => {
let response: ChatCompletion;
const body = toGroqRequestBody(name, request);
const streamingCallback:
| StreamingCallback<GenerateResponseChunkData>
| undefined = options?.streamingCallback;
if (streamingCallback) {
if (request.output?.format === 'json') {
throw new Error(
Expand Down
54 changes: 33 additions & 21 deletions plugins/groq/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
// Import necessary types and functions for Groq SDK integration.
import Groq from 'groq-sdk';
import {
groqModel,
createGroqModel,
llama3x70b,
llama3x8b,
llamaGuard3x8b,
Expand All @@ -32,8 +32,7 @@ import {
deepseekR1DistillLlamax70b,
SUPPORTED_GROQ_MODELS,
} from './groq_models';
import { Genkit } from 'genkit';
import { genkitPlugin } from 'genkit/plugin';
import { genkitPluginV2 } from 'genkit/plugin';

// Export models for direct access
export {
Expand Down Expand Up @@ -93,26 +92,39 @@ export interface PluginOptions {
* @returns An object containing the models initialized with the Groq client.
*/
export const groq = (options?: PluginOptions) =>
genkitPlugin('groq', async (ai: Genkit) => {
const apiKey = options?.apiKey || process.env.GROQ_API_KEY;
if (!apiKey) {
throw new Error(
'Please provide the API key or set the GROQ_API_KEY environment variable'
);
}
genkitPluginV2({
name: 'groq',
init: async () => {
const apiKey = options?.apiKey || process.env.GROQ_API_KEY;
if (!apiKey) {
throw new Error(
'Please provide the API key or set the GROQ_API_KEY environment variable'
);
}

// Initialize Groq client
const client = new Groq({
baseURL: options?.baseURL || process.env.GROQ_BASE_URL, // Optional base URL with environment variable fallback
apiKey, // API key retrieved from options or environment
timeout: options?.timeout, // Optional timeout
maxRetries: options?.maxRetries, // Optional max retries
});
// Initialize Groq client
const client = new Groq({
baseURL: options?.baseURL || process.env.GROQ_BASE_URL, // Optional base URL with environment variable fallback
apiKey, // API key retrieved from options or environment
timeout: options?.timeout, // Optional timeout
maxRetries: options?.maxRetries, // Optional max retries
});

// Register each model with the Genkit instance
for (const name of Object.keys(SUPPORTED_GROQ_MODELS)) {
groqModel(ai, name, client);
}
// Create model actions for each supported model
const models: any[] = [];
for (const name of Object.keys(SUPPORTED_GROQ_MODELS)) {
models.push(createGroqModel(name, client));
}

return models;
},
list: async () => {
return Object.keys(SUPPORTED_GROQ_MODELS).map((name) => ({
name: `groq/${name}`,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the groq/ part?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, groq/ was initially being appended to the model name.

The prefix groq/ was added here:
https://github.com/BloomLabsInc/genkit-plugins/pull/9/files#diff-9c9efbee9ae1c4a79760d6d2d0f89c9d46d538f645235c21e3f51f3c4c9e35aeR454-R457

When the initialized Genkit object is logged while using the genkitx-groq package (which has the v1 plugin under the hood), the resulting model name will include this prefix.

image

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's necessary anymore in v2.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alright, removing it shortly

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@HassanBahati I think we're going for namespace: 'pluginName' and name: 'modelName' instead. So can we apply that here and to your other migrations too.

For example:

{
  name: 'llama3-70b-8192',
  namespace: 'groq'
  ...
}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @CorieW, i've updated

type: 'model' as const,
info: SUPPORTED_GROQ_MODELS[name].info,
}));
},
});

// Default export for plugin usage
Expand Down
43 changes: 43 additions & 0 deletions plugins/groq/tests/groq_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
toGroqMessages,
} from '../src/groq_models';
import { ChatCompletionCreateParamsBase } from 'groq-sdk/resources/chat/completions.mjs';
import { groq } from '../src/index';

describe('toGroqRole', () => {
it('should convert user role correctly', () => {
Expand Down Expand Up @@ -97,3 +98,45 @@ describe('toGroqRequestBody', () => {
});
});
});

describe('Groq Plugin', () => {
it('should create plugin with v2 API', () => {
const plugin = groq({ apiKey: 'test-key' });

// Check that the plugin has the expected structure
assert.strictEqual(plugin.name, 'groq');
assert(typeof plugin.init === 'function');
assert(typeof plugin.list === 'function');
});

it('should list available models', async () => {
const plugin = groq({ apiKey: 'test-key' });
const models = await plugin.list?.();

// Check that we get a list of models
assert(Array.isArray(models));
assert(models.length > 0);

// Check that each model has the expected structure
for (const model of models) {
assert.strictEqual((model as any).type, 'model');
assert(typeof model.name === 'string');
assert(model.name.startsWith('groq/'));
assert(typeof (model as any).info === 'object');
}
});

it('should initialize models', async () => {
const plugin = groq({ apiKey: 'test-key' });
const models = await plugin.init?.();

// Check that we get an array of model actions
assert(Array.isArray(models));
assert(models.length > 0);

// Check that each model is a function (action)
for (const model of models) {
assert(typeof model === 'function');
}
});
});