126 changes: 126 additions & 0 deletions app/lib/chat_models/ai-mask-chat.ts
@@ -0,0 +1,126 @@
import {
SimpleChatModel,
type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import type { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base";
import { BaseMessage, AIMessageChunk } from "@langchain/core/messages";
import { AIMaskClient, ChatCompletionMessageParam } from '@ai-mask/sdk';
import { ChatGenerationChunk } from "@langchain/core/outputs";

export interface AIMaskInputs extends BaseChatModelParams {
modelId: string;
temperature?: number;
aiMaskClient?: AIMaskClient;
appName?: string;
}

export interface AIMaskCallOptions extends BaseLanguageModelCallOptions {
}

// Convert LangChain messages into the role/content pairs expected by the AI Mask SDK.
function convertMessages(messages: BaseMessage[]): ChatCompletionMessageParam[] {
return messages.map((message) => {
let role: ChatCompletionMessageParam['role'], content: ChatCompletionMessageParam['content'];
if (message._getType() === "human") {
role = "user";
} else if (message._getType() === "ai") {
role = "assistant";
} else if (message._getType() === "system") {
role = "system";
} else {
throw new Error(
`Unsupported message type for AIMask: ${message._getType()}`
);
}
if (typeof message.content === "string") {
content = message.content;
} else {
throw new Error("Unsupported content type for AIMask: only string message content is supported.");
}
return { role, content };
});
}

/**
* @example
* ```typescript
* // Initialize the ChatAIMask model with the ID of a model supported by the AI Mask extension.
* const model = new ChatAIMask({
* modelId: "Mistral-7B-Instruct-v0.2-q4f16_1",
* });
*
* // Call the model with a message and await the response.
* const response = await model.call([
* new HumanMessage({ content: "My name is John." }),
* ]);
*
* // Log the response to the console.
* console.log({ response });
*
* ```
*/
export class ChatAIMask extends SimpleChatModel<AIMaskCallOptions> {
static inputs: AIMaskInputs;

protected _aiMaskClient: AIMaskClient;

modelId: string;
temperature?: number;

static lc_name() {
return "ChatAIMask";
}

constructor(inputs: AIMaskInputs) {
super(inputs);

this._aiMaskClient = inputs?.aiMaskClient ?? new AIMaskClient({ name: inputs?.appName });

this.modelId = inputs.modelId;
this.temperature = inputs.temperature;
}

_llmType() {
return "ai-mask";
}

async *_streamResponseChunks(
messages: BaseMessage[],
): AsyncGenerator<ChatGenerationChunk> {
const stream = await this._aiMaskClient.chat({
messages: convertMessages(messages),
temperature: this.temperature,
}, {
modelId: this.modelId,
stream: true,
});

for await (const chunk of stream) {
const text = chunk;
yield new ChatGenerationChunk({
text,
message: new AIMessageChunk({
content: text,
}),
});
}
}

async _call(
messages: BaseMessage[],
): Promise<string> {
try {
const completion = await this._aiMaskClient.chat({
messages: convertMessages(messages),
temperature: this.temperature,
}, {
modelId: this.modelId,
});
return completion;
} catch (e) {
throw new Error(`Error getting prompt completion: ${e}`);
}
}
}
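
For reference, a minimal usage sketch of the ChatAIMask model defined above. It assumes the AI Mask browser extension is installed and that the Mistral model ID from the JSDoc example is available; invoke() and stream() are the standard LangChain entry points, which in turn call _call and _streamResponseChunks.

import { HumanMessage } from "@langchain/core/messages";

// Model ID taken from the JSDoc example above; it must be one of the models
// exposed by the user's AI Mask extension.
const chat = new ChatAIMask({
  modelId: "Mistral-7B-Instruct-v0.2-q4f16_1",
});

// Non-streaming: resolves once the full completion is returned by the extension.
const reply = await chat.invoke([new HumanMessage("My name is John.")]);
console.log(reply.content);

// Streaming: yields AIMessageChunk objects as text arrives, backed by
// _streamResponseChunks above.
const stream = await chat.stream([new HumanMessage("Tell me a short joke.")]);
for await (const chunk of stream) {
  console.log(chunk.content);
}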
108 changes: 108 additions & 0 deletions app/lib/embeddings/ai-mask-embeddings.ts
@@ -0,0 +1,108 @@
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
import { chunkArray } from "@langchain/core/utils/chunk_array";
import { AIMaskClient } from '@ai-mask/sdk';

export interface AIMaskEmbeddingsParams
extends EmbeddingsParams {
/** Model name to use */
modelName: string;

/**
* Timeout to use when making requests to the AI Mask client.
*/
timeout?: number;

/**
* The maximum number of documents to embed in a single request.
*/
batchSize?: number;

/**
* Whether to strip new lines from the input text. Stripping can improve
* results for some embedding models, but may not be suitable for all use cases.
*/
stripNewLines?: boolean;
aiMaskClient?: AIMaskClient;
appName?: string;
}

/**
* @example
* ```typescript
* const model = new AIMaskEmbeddings({
* modelName: "Xenova/all-MiniLM-L6-v2",
* });
*
* // Embed a single query
* const res = await model.embedQuery(
* "What would be a good company name for a company that makes colorful socks?"
* );
* console.log({ res });
*
* // Embed multiple documents
* const documentRes = await model.embedDocuments(["Hello world", "Bye bye"]);
* console.log({ documentRes });
* ```
*/
export class AIMaskEmbeddings
extends Embeddings
implements AIMaskEmbeddingsParams {
modelName = "Xenova/all-MiniLM-L6-v2";

batchSize = 512;

stripNewLines = true;

timeout?: number;

protected _aiMaskClient: AIMaskClient;

constructor(fields?: Partial<AIMaskEmbeddingsParams>) {
super(fields ?? {});

this.modelName = fields?.modelName ?? this.modelName;
this.batchSize = fields?.batchSize ?? this.batchSize;
this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines;
this.timeout = fields?.timeout;

this._aiMaskClient = fields?.aiMaskClient ?? new AIMaskClient({ name: fields?.appName });
}

async embedDocuments(texts: string[]): Promise<number[][]> {
const batches = chunkArray(
this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts,
this.batchSize
);

const batchRequests = batches.map((batch) => this.runEmbedding(batch));
const batchResponses = await Promise.all(batchRequests);
const embeddings: number[][] = [];

for (let i = 0; i < batchResponses.length; i += 1) {
const batchResponse = batchResponses[i];
for (let j = 0; j < batchResponse.length; j += 1) {
embeddings.push(batchResponse[j]);
}
}

return embeddings;
}

async embedQuery(text: string): Promise<number[]> {
const data = await this.runEmbedding([
this.stripNewLines ? text.replace(/\n/g, " ") : text,
]);
return data[0];
}

// Send one batch of texts to the AI Mask extension for feature extraction,
// with retries and concurrency handled by the base class caller.
private async runEmbedding(texts: string[]) {
return this.caller.call(async () => {
const output = await this._aiMaskClient.featureExtraction(
{ texts, pooling: "mean", normalize: true },
{ modelId: this.modelName }
);
return output;
});
}
}
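
And a minimal usage sketch of AIMaskEmbeddings, under the same assumption that the AI Mask extension is installed and serves the default Xenova/all-MiniLM-L6-v2 feature-extraction model; the query and document strings are placeholders.

const embeddings = new AIMaskEmbeddings({
  modelName: "Xenova/all-MiniLM-L6-v2",
});

// Embed a single query string; resolves to one vector of numbers.
const queryVector = await embeddings.embedQuery("What does AI Mask do?");
console.log(queryVector.length);

// Embed several documents; inputs are chunked by batchSize and each batch is
// sent to the extension through featureExtraction.
const docVectors = await embeddings.embedDocuments([
  "AI Mask runs models locally through a browser extension.",
  "LangChain embeddings expose embedQuery and embedDocuments.",
]);
console.log(docVectors.length); // 2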