Skip to content

refactor(js): moved ALS code into separate node specific package #3061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ function maybeRegisterDynamicTools<
hasDynamicTools = true;
// Create a temporary registry with dynamic tools for the duration of this
// generate request.
registry = Registry.withParent(registry);
registry = registry.child();
}
registry.registerAction('tool', t as Action);
}
Expand Down
3 changes: 2 additions & 1 deletion js/ai/tests/formats/format_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

import { z } from '@genkit-ai/core';
import { NodeRegistry } from '@genkit-ai/core/node';
import { Registry } from '@genkit-ai/core/registry';
import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
Expand All @@ -24,7 +25,7 @@ describe('formats', () => {
let registry: Registry;

beforeEach(() => {
registry = new Registry();
registry = new NodeRegistry();
configureFormats(registry);
});

Expand Down
9 changes: 2 additions & 7 deletions js/ai/tests/formats/json_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

import { z } from '@genkit-ai/core';
import { NodeRegistry } from '@genkit-ai/core/node';
import { Registry } from '@genkit-ai/core/registry';
import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
Expand All @@ -29,12 +30,6 @@ import type {
import { defineProgrammableModel, runAsync } from '../helpers.js';

describe('jsonFormat', () => {
let registry: Registry;

beforeEach(() => {
registry = new Registry();
});

const streamingTests = [
{
desc: 'parses complete JSON object',
Expand Down Expand Up @@ -141,7 +136,7 @@ describe('jsonFormat e2e', () => {
let registry: Registry;

beforeEach(() => {
registry = new Registry();
registry = new NodeRegistry();
configureFormats(registry);
});

Expand Down
3 changes: 2 additions & 1 deletion js/ai/tests/generate/action_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

import { stripUndefinedProps, z } from '@genkit-ai/core';
import { NodeRegistry } from '@genkit-ai/core/node';
import { Registry } from '@genkit-ai/core/registry';
import * as assert from 'assert';
import { readFileSync } from 'fs';
Expand Down Expand Up @@ -58,7 +59,7 @@ describe('spec', () => {
let pm: ProgrammableModel;

beforeEach(() => {
registry = new Registry();
registry = new NodeRegistry();
defineGenerateAction(registry);
pm = defineProgrammableModel(registry);
defineTool(
Expand Down
9 changes: 5 additions & 4 deletions js/ai/tests/generate/generate_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

import { z, type PluginProvider } from '@genkit-ai/core';
import { NodeRegistry } from '@genkit-ai/core/node';
import { Registry } from '@genkit-ai/core/registry';
import * as assert from 'assert';
import { beforeEach, describe, it } from 'node:test';
Expand All @@ -32,7 +33,7 @@ import {
import { defineTool } from '../../src/tool.js';

describe('toGenerateRequest', () => {
const registry = new Registry();
const registry = new NodeRegistry();
// register tools
const tellAFunnyJoke = defineTool(
registry,
Expand Down Expand Up @@ -332,7 +333,7 @@ describe('generate', () => {
var echoModel: ModelAction;

beforeEach(() => {
registry = new Registry();
registry = new NodeRegistry();
echoModel = defineModel(
registry,
{
Expand Down Expand Up @@ -407,7 +408,7 @@ describe('generate', () => {
describe('generate', () => {
let registry: Registry;
beforeEach(() => {
registry = new Registry();
registry = new NodeRegistry();

defineModel(
registry,
Expand All @@ -432,7 +433,7 @@ describe('generate', () => {

describe('generateStream', () => {
it('should stream out chunks', async () => {
const registry = new Registry();
const registry = new NodeRegistry();

defineModel(
registry,
Expand Down
5 changes: 3 additions & 2 deletions js/ai/tests/model/middleware_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

import { z } from '@genkit-ai/core';
import { NodeRegistry } from '@genkit-ai/core/node';
import { Registry } from '@genkit-ai/core/registry';
import * as assert from 'assert';
import { beforeEach, describe, it } from 'node:test';
Expand Down Expand Up @@ -137,7 +138,7 @@ describe('validateSupport', () => {
});
});

const registry = new Registry();
const registry = new NodeRegistry();
configureFormats(registry);

const echoModel = defineModel(registry, { name: 'echo' }, async (req) => {
Expand Down Expand Up @@ -400,7 +401,7 @@ describe.only('simulateConstrainedGeneration', () => {
let registry: Registry;

beforeEach(() => {
registry = new Registry();
registry = new NodeRegistry();
configureFormats(registry);
});

Expand Down
4 changes: 2 additions & 2 deletions js/ai/tests/prompt/prompt_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/

import { runWithContext, z, type ActionContext } from '@genkit-ai/core';
import { Registry } from '@genkit-ai/core/registry';
import { NodeRegistry } from '@genkit-ai/core/node';
import { toJsonSchema } from '@genkit-ai/core/schema';
import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
Expand All @@ -34,7 +34,7 @@ describe('prompt', () => {
let registry;

beforeEach(() => {
registry = new Registry();
registry = new NodeRegistry();

defineEchoModel(registry);
defineTool(
Expand Down
3 changes: 2 additions & 1 deletion js/ai/tests/reranker/reranker_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

import { GenkitError, z } from '@genkit-ai/core';
import { NodeRegistry } from '@genkit-ai/core/node';
import { Registry } from '@genkit-ai/core/registry';
import * as assert from 'assert';
import { beforeEach, describe, it } from 'node:test';
Expand All @@ -25,7 +26,7 @@ describe('reranker', () => {
describe('defineReranker()', () => {
let registry: Registry;
beforeEach(() => {
registry = new Registry();
registry = new NodeRegistry();
});
it('reranks documents based on custom logic', async () => {
const customReranker = defineReranker(
Expand Down
10 changes: 5 additions & 5 deletions js/ai/tests/tool_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
*/

import { z } from '@genkit-ai/core';
import { Registry } from '@genkit-ai/core/registry';
import { NodeRegistry } from '@genkit-ai/core/node';
import * as assert from 'assert';
import { afterEach, describe, it } from 'node:test';
import { defineInterrupt, defineTool } from '../src/tool.js';

describe('defineInterrupt', () => {
let registry = new Registry();
let registry = new NodeRegistry();
registry.apiStability = 'beta';

afterEach(() => {
registry = new Registry();
registry = new NodeRegistry();
registry.apiStability = 'beta';
});

Expand Down Expand Up @@ -107,10 +107,10 @@ describe('defineInterrupt', () => {
});

describe('defineTool', () => {
let registry = new Registry();
let registry = new NodeRegistry();
registry.apiStability = 'beta';
afterEach(() => {
registry = new Registry();
registry = new NodeRegistry();
registry.apiStability = 'beta';
});

Expand Down
9 changes: 9 additions & 0 deletions js/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@
"require": "./lib/schema.js",
"import": "./lib/schema.mjs",
"default": "./lib/schema.js"
},
"./node": {
"types": "./lib/node.d.ts",
"require": "./lib/node.js",
"import": "./lib/node.mjs",
"default": "./lib/node.js"
}
},
"typesVersions": {
Expand All @@ -116,6 +122,9 @@
],
"schema": [
"lib/schema"
],
"node": [
"lib/node"
]
}
}
Expand Down
36 changes: 36 additions & 0 deletions js/core/src/als-async-store.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { AsyncLocalStorage } from 'node:async_hooks';
import { AsyncStore } from './registry.js';

/**
* Node AsyncLocalStorage based AsyncStore impl.
*/
export class AlsAsyncStore implements AsyncStore {
private asls: Record<string, AsyncLocalStorage<any>> = {};

getStore<T>(key: string): T | undefined {
return this.asls[key]?.getStore();
}

run<T, R>(key: string, store: T, callback: () => R): R {
if (!this.asls[key]) {
this.asls[key] = new AsyncLocalStorage<T>();
}
return this.asls[key].run(store, callback);
}
}
33 changes: 22 additions & 11 deletions js/core/src/flow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@
* limitations under the License.
*/

import { AsyncLocalStorage } from 'node:async_hooks';
import type { z } from 'zod';
import { defineAction, type Action, type StreamingCallback } from './action.js';
import type { ActionContext } from './context.js';
import { Registry, type HasRegistry } from './registry.js';
import {
Registry,
_getAsyncStoreFactory,
type HasRegistry,
} from './registry.js';
import { SPAN_TYPE_ATTR, runInNewSpan } from './tracing.js';

const legacyRegistryAlsKey = 'legacyRegistryAls';

/**
* Flow is an observable, streamable, (optionally) strongly typed function.
*/
Expand Down Expand Up @@ -132,18 +137,24 @@ function defineFlowAction<
metadata: config.metadata,
},
async (input, { sendChunk, context, trace }) => {
return await legacyRegistryAls.run(registry, () => {
const ctx = sendChunk;
(ctx as FlowSideChannel<z.infer<S>>).sendChunk = sendChunk;
(ctx as FlowSideChannel<z.infer<S>>).context = context;
(ctx as FlowSideChannel<z.infer<S>>).trace = trace;
return fn(input, ctx as FlowSideChannel<z.infer<S>>);
});
return await legacyRegistryAls().run(
legacyRegistryAlsKey,
registry,
() => {
const ctx = sendChunk;
(ctx as FlowSideChannel<z.infer<S>>).sendChunk = sendChunk;
(ctx as FlowSideChannel<z.infer<S>>).context = context;
(ctx as FlowSideChannel<z.infer<S>>).trace = trace;
return fn(input, ctx as FlowSideChannel<z.infer<S>>);
}
);
}
);
}

const legacyRegistryAls = new AsyncLocalStorage<Registry>();
function legacyRegistryAls() {
return _getAsyncStoreFactory()();
}

export function run<T>(
name: string,
Expand Down Expand Up @@ -191,7 +202,7 @@ export function run<T>(
}

if (!registry) {
registry = legacyRegistryAls.getStore();
registry = legacyRegistryAls().getStore(legacyRegistryAlsKey);
}
if (!registry) {
throw new Error(
Expand Down
1 change: 0 additions & 1 deletion js/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ export {
type FlowSideChannel,
} from './flow.js';
export * from './plugin.js';
export * from './reflection.js';
export { defineJsonSchema, defineSchema, type JSONSchema } from './schema.js';
export * from './telemetryTypes.js';
export * from './utils.js';
37 changes: 37 additions & 0 deletions js/core/src/node.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { AlsAsyncStore } from './als-async-store.js';
import { _setAsyncStoreFactory, Registry } from './registry.js';
export * from './reflection.js';

export class NodeRegistry extends Registry {
constructor(parent?: NodeRegistry) {
if (parent) {
super(parent);
} else {
const store = new AlsAsyncStore();
const asyncStoreFactory = () => store;

super({ asyncStoreFactory });
_setAsyncStoreFactory(asyncStoreFactory);
}
}

child(): Registry {
return new NodeRegistry(this);
}
}
Loading