Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 37 additions & 31 deletions core/src/agents/base_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
*/

import {Content} from '@google/genai';
import {trace} from '@opentelemetry/api';
import {context, trace} from '@opentelemetry/api';

import {createEvent, Event} from '../events/event.js';

import {CallbackContext} from './callback_context.js';
import {InvocationContext} from './invocation_context.js';
import {runAsyncGeneratorWithOtelContext, traceAgentInvocation, tracer} from '../telemetry/tracing.js';

type SingleAgentCallback = (context: CallbackContext) =>
Promise<Content|undefined>|(Content|undefined);
Expand Down Expand Up @@ -124,34 +125,37 @@ export abstract class BaseAgent {
async *
runAsync(parentContext: InvocationContext):
AsyncGenerator<Event, void, void> {
const span = trace.getTracer('gcp.vertex.agent')
.startSpan(`agent_run [${this.name}]`);
const span = tracer.startSpan(`invoke_agent ${this.name}`);
const ctx = trace.setSpan(context.active(), span);
try {
const context = this.createInvocationContext(parentContext);

const beforeAgentCallbackEvent =
await this.handleBeforeAgentCallback(context);
if (beforeAgentCallbackEvent) {
yield beforeAgentCallbackEvent;
}

if (context.endInvocation) {
return;
}

for await (const event of this.runAsyncImpl(context)) {
yield event;
}

if (context.endInvocation) {
return;
}

const afterAgentCallbackEvent =
await this.handleAfterAgentCallback(context);
if (afterAgentCallbackEvent) {
yield afterAgentCallbackEvent;
}
yield* runAsyncGeneratorWithOtelContext<BaseAgent, Event>(ctx, this, async function* () {
const context = this.createInvocationContext(parentContext);

const beforeAgentCallbackEvent =
await this.handleBeforeAgentCallback(context);
if (beforeAgentCallbackEvent) {
yield beforeAgentCallbackEvent;
}

if (context.endInvocation) {
return;
}

traceAgentInvocation({agent: this, invocationContext: context});
for await (const event of this.runAsyncImpl(context)) {
yield event;
}

if (context.endInvocation) {
return;
}

const afterAgentCallbackEvent =
await this.handleAfterAgentCallback(context);
if (afterAgentCallbackEvent) {
yield afterAgentCallbackEvent;
}
});
} finally {
span.end();
}
Expand All @@ -167,10 +171,12 @@ export abstract class BaseAgent {
async *
runLive(parentContext: InvocationContext):
AsyncGenerator<Event, void, void> {
const span = trace.getTracer('gcp.vertex.agent')
.startSpan(`agent_run [${this.name}]`);
const span = tracer.startSpan(`invoke_agent ${this.name}`);
const ctx = trace.setSpan(context.active(), span);
try {
// TODO(b/425992518): Implement live mode.
yield* runAsyncGeneratorWithOtelContext<BaseAgent, Event>(ctx, this, async function* () {
// TODO(b/425992518): Implement live mode.
});
throw new Error('Live mode is not implemented yet.');
} finally {
span.end();
Expand Down
69 changes: 59 additions & 10 deletions core/src/agents/functions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
* SPDX-License-Identifier: Apache-2.0
*/

// TODO - b/436079721: implement traceMergedToolCalls, traceToolCall, tracer.
import {Content, createUserContent, FunctionCall, Part} from '@google/genai';

import {InvocationContext} from '../agents/invocation_context.js';
Expand All @@ -17,6 +16,7 @@ import {randomUUID} from '../utils/env_aware_utils.js';
import {logger} from '../utils/logger.js';

import {SingleAfterToolCallback, SingleBeforeToolCallback} from './llm_agent.js';
import {traceMergedToolCalls, tracer, traceToolCall} from '../telemetry/tracing.js';

const AF_FUNCTION_CALL_ID_PREFIX = 'adk-';
export const REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential';
Expand Down Expand Up @@ -191,11 +191,54 @@ async function callToolAsync(
args: Record<string, any>,
toolContext: ToolContext,
): Promise<any> {
// TODO - b/436079721: implement [tracer.start_as_current_span]
logger.debug(`callToolAsync ${tool.name}`);
return await tool.runAsync({args, toolContext});
return tracer.startActiveSpan(`execute_tool ${tool.name}`, async (span) => {
try {
logger.debug(`callToolAsync ${tool.name}`);
const result = await tool.runAsync({args, toolContext});
traceToolCall({
tool,
args,
functionResponseEvent: buildResponseEvent(tool, result, toolContext, toolContext.invocationContext)
})
return result;
} finally {
span.end();
}
});
}

function buildResponseEvent(
tool: BaseTool,
functionResult: any,
toolContext: ToolContext,
invocationContext: InvocationContext,
): Event {
let responseResult = functionResult;
if (typeof functionResult !== 'object' || functionResult == null) {
responseResult = {result: functionResult};
}

const partFunctionResponse: Part = {
functionResponse: {
name: tool.name,
response: responseResult,
id: toolContext.functionCallId,
},
};

const content: Content = {
role: 'user',
parts: [partFunctionResponse],
};

return createEvent({
invocationId: invocationContext.invocationId,
author: invocationContext.agent.name,
content: content,
actions: toolContext.actions,
branch: invocationContext.branch,
});
}
/**
* Handles function calls.
* Runtime behavior to pay attention to:
Expand Down Expand Up @@ -425,12 +468,18 @@ export async function handleFunctionCallList({
mergeParallelFunctionResponseEvents(functionResponseEvents);

if (functionResponseEvents.length > 1) {
// TODO - b/436079721: implement [tracer.start_as_current_span]
logger.debug('execute_tool (merged)');
// TODO - b/436079721: implement [traceMergedToolCalls]
logger.debug('traceMergedToolCalls', {
responseEventId: mergedEvent.id,
functionResponseEvent: mergedEvent.id,
tracer.startActiveSpan('execute_tool (merged)', (span) => {
try {
logger.debug('execute_tool (merged)');
// TODO - b/436079721: implement [traceMergedToolCalls]
logger.debug('traceMergedToolCalls', {
responseEventId: mergedEvent.id,
functionResponseEvent: mergedEvent.id,
});
traceMergedToolCalls({responseEventId: mergedEvent.id, functionResponseEvent: mergedEvent});
} finally {
span.end();
}
});
}
return mergedEvent;
Expand Down
20 changes: 15 additions & 5 deletions core/src/agents/llm_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/

import {FunctionCall, GenerateContentConfig, Schema} from '@google/genai';
import {context, trace} from '@opentelemetry/api';
import {z} from 'zod';

import {createEvent, createNewEventId, Event, getFunctionCalls, getFunctionResponses, isFinalResponse} from '../events/event.js';
Expand All @@ -30,6 +31,7 @@ import {injectSessionState} from './instructions.js';
import {InvocationContext} from './invocation_context.js';
import {ReadonlyContext} from './readonly_context.js';
import {StreamingMode} from './run_config.js';
import {runAsyncGeneratorWithOtelContext, traceCallLlm, tracer} from '../telemetry/tracing.js';

/** An object that can provide an instruction string. */
export type InstructionProvider = (
Expand Down Expand Up @@ -1055,7 +1057,10 @@ export class LlmAgent extends BaseAgent {
author: this.name,
branch: invocationContext.branch,
});
for await (const llmResponse of this.callLlmAsync(
const span = tracer.startSpan('call_llm');
const ctx = trace.setSpan(context.active(), span);
yield* runAsyncGeneratorWithOtelContext<LlmAgent, Event>(ctx, this, async function* () {
for await (const llmResponse of this.callLlmAsync(
invocationContext, llmRequest, modelResponseEvent)) {
// ======================================================================
// Postprocess after calling the LLM
Expand All @@ -1066,8 +1071,10 @@ export class LlmAgent extends BaseAgent {
modelResponseEvent.id = createNewEventId();
modelResponseEvent.timestamp = new Date().getTime();
yield event;
}
}
}
});
span.end();
}

private async *
Expand Down Expand Up @@ -1217,7 +1224,6 @@ export class LlmAgent extends BaseAgent {

// Calls the LLM.
const llm = this.canonicalModel;
// TODO - b/436079721: Add tracer.start_as_current_span('call_llm')
if (invocationContext.runConfig?.supportCfc) {
// TODO - b/425992518: Implement CFC call path
// This is a hack, underneath it calls runLive. Which makes
Expand All @@ -1234,8 +1240,12 @@ export class LlmAgent extends BaseAgent {
for await (const llmResponse of this.runAndHandleError(
responsesGenerator, invocationContext, llmRequest,
modelResponseEvent)) {
// TODO - b/436079721: Add trace_call_llm

traceCallLlm({
invocationContext,
eventId: modelResponseEvent.id,
llmRequest,
llmResponse,
});
// Runs after_model_callback if it exists.
const alteredLlmResponse = await this.handleAfterModelCallback(
invocationContext, llmResponse, modelResponseEvent);
Expand Down
Loading