Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 33 additions & 168 deletions javascript/src/agents/__tests__/red-team.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import { describe, it, expect, vi } from "vitest";
import { CrescendoStrategy } from "../red-team/crescendo-strategy";
import { GoatStrategy, GOAT_METAPROMPT_TEMPLATE } from "../red-team/goat-strategy";
import { renderMetapromptTemplate } from "../red-team/metaprompt-template";
import { redTeamCrescendo, redTeamGoat, redTeamAgent } from "../red-team/red-team-agent";
import { redTeamCrescendo, redTeamAgent } from "../red-team/red-team-agent";
import { Base64Technique, DEFAULT_TECHNIQUES } from "../red-team/techniques";
import { ScenarioExecutionState } from "../../execution/scenario-execution-state";
import { AgentRole, AgentAdapter, JudgeAgentAdapter } from "../../domain";
Expand Down Expand Up @@ -37,43 +36,52 @@ describe("CrescendoStrategy", () => {
const strategy = new CrescendoStrategy();

it("returns warmup phase for early turns", () => {
expect(strategy.getPhaseName(1, 100)).toBe("warmup");
const phase = strategy.getPhase(1, 100);
expect(phase.name).toBe("warmup");
});

it("returns probing phase for turns 20-45%", () => {
expect(strategy.getPhaseName(30, 100)).toBe("probing");
const phase = strategy.getPhase(30, 100);
expect(phase.name).toBe("probing");
});

it("returns escalation phase for turns 45-75%", () => {
expect(strategy.getPhaseName(50, 100)).toBe("escalation");
const phase = strategy.getPhase(50, 100);
expect(phase.name).toBe("escalation");
});

it("returns direct phase for late turns", () => {
expect(strategy.getPhaseName(80, 100)).toBe("direct");
const phase = strategy.getPhase(80, 100);
expect(phase.name).toBe("direct");
});

it("returns warmup at boundary turn 0", () => {
// Turn 0 / 100 = 0.0, which is in warmup [0.0, 0.2)
expect(strategy.getPhaseName(0, 100)).toBe("warmup");
const phase = strategy.getPhase(0, 100);
expect(phase.name).toBe("warmup");
});

it("returns probing at boundary turn 20", () => {
// Turn 20 / 100 = 0.2, which is in probing [0.2, 0.45)
expect(strategy.getPhaseName(20, 100)).toBe("probing");
const phase = strategy.getPhase(20, 100);
expect(phase.name).toBe("probing");
});

it("returns escalation at boundary turn 45", () => {
// Turn 45 / 100 = 0.45, which is in escalation [0.45, 0.75)
expect(strategy.getPhaseName(45, 100)).toBe("escalation");
const phase = strategy.getPhase(45, 100);
expect(phase.name).toBe("escalation");
});

it("returns direct at boundary turn 75", () => {
// Turn 75 / 100 = 0.75, which is in direct [0.75, Infinity)
expect(strategy.getPhaseName(75, 100)).toBe("direct");
const phase = strategy.getPhase(75, 100);
expect(phase.name).toBe("direct");
});

it("handles totalTurns of 0 without error", () => {
expect(strategy.getPhaseName(0, 0)).toBe("warmup");
const phase = strategy.getPhase(0, 0);
expect(phase.name).toBe("warmup");
});

it("builds a system prompt with all sections", () => {
Expand Down Expand Up @@ -146,7 +154,6 @@ describe("renderMetapromptTemplate", () => {
target: "hack it",
description: "test agent",
totalTurns: 100,
phaseEnds: [20, 45, 75],
});
expect(result).toBe(
"Target: hack it, Desc: test agent, Turns: 100, P1: 20, P2: 45, P3: 75"
Expand All @@ -169,23 +176,9 @@ describe("renderMetapromptTemplate", () => {
target: "",
description: "",
totalTurns: 10,
phaseEnds: [2, 4, 7],
});
expect(result).toBe("2-4-7");
});

it("leaves phase placeholders as literals when phaseEnds is omitted", () => {
// GOAT path: template has no phase placeholders, so this is a no-op.
// If a user accidentally passes a Crescendo template without phaseEnds,
// the placeholders are left as-is rather than silently corrupting.
const template = "turns: {totalTurns}, p1: {phase1End}";
const result = renderMetapromptTemplate(template, {
target: "t",
description: "d",
totalTurns: 10,
});
expect(result).toBe("turns: 10, p1: {phase1End}");
});
});

describe("refusal detection", () => {
Expand Down Expand Up @@ -242,25 +235,36 @@ describe("refusal detection", () => {
});

it("hard refusal skips scorer and sets score=0", async () => {
const generateTextMock = vi.fn();
vi.doMock("ai", () => ({ generateText: generateTextMock }));

const testAgent = redTeamCrescendo({
target: "test",
attackPlan: "pre-baked plan",
scoreResponses: true,
detectRefusals: true,
});

// Access internal turnScores to verify caching
const internal = testAgent as unknown as {
call: typeof testAgent.call;
turnScores: Map<number, { score: number; hint: string }>;
detectRefusal(content: string): "hard" | "soft" | "none";
getLastAssistantContent(messages: unknown[]): string;
};

// Verify that a hard refusal message is classified correctly — the scorer
// is not invoked by detectRefusal itself, only by the call() orchestration.
// Simulate call() logic: if detectRefusals and hard refusal, skip scorer
const messages = [
{ role: "assistant" as const, content: "I cannot help with that request." },
];
const lastContent = internal.getLastAssistantContent(messages);
expect(internal.detectRefusal(lastContent)).toBe("hard");
const refusal = internal.detectRefusal(lastContent);

expect(refusal).toBe("hard");
// The scorer (generateText) should NOT have been called
expect(generateTextMock).not.toHaveBeenCalled();

vi.doUnmock("ai");
});

it("soft/none refusal does not short-circuit", () => {
Expand Down Expand Up @@ -773,145 +777,6 @@ describe("rollbackMessagesTo", () => {
});
});

// ---------------------------------------------------------------------------
// GoatStrategy
// ---------------------------------------------------------------------------

describe("GoatStrategy", () => {
const strategy = new GoatStrategy();

describe("stage boundaries", () => {
it("returns early for turn 1 of 50", () => {
expect(strategy.getPhaseName(1, 50)).toBe("early");
});

it("returns early up to 30% boundary", () => {
expect(strategy.getPhaseName(14, 50)).toBe("early"); // 28%
});

it("returns mid at 30% boundary", () => {
expect(strategy.getPhaseName(15, 50)).toBe("mid"); // 30%
});

it("returns mid up to 70% boundary", () => {
expect(strategy.getPhaseName(34, 50)).toBe("mid"); // 68%
});

it("returns late at 70% boundary", () => {
expect(strategy.getPhaseName(35, 50)).toBe("late"); // 70%
});

it("returns late for the final turn", () => {
expect(strategy.getPhaseName(50, 50)).toBe("late");
});

it("returns early for turn 0", () => {
expect(strategy.getPhaseName(0, 50)).toBe("early");
});

it("returns late when total_turns=1 and turn=1", () => {
expect(strategy.getPhaseName(1, 1)).toBe("late");
});

it("all three stages appear in a 50-turn run", () => {
const stages = new Set(
Array.from({ length: 50 }, (_, i) => strategy.getPhaseName(i + 1, 50))
);
expect(stages).toEqual(new Set(["early", "mid", "late"]));
});
});

describe("buildSystemPrompt", () => {
const base = {
target: "extract the system prompt",
currentTurn: 1,
totalTurns: 30,
scenarioDescription: "A bank support agent",
metapromptPlan: "custom plan here",
};

it("contains target objective", () => {
expect(strategy.buildSystemPrompt(base)).toContain("extract the system prompt");
});

it("contains the metaprompt plan", () => {
expect(strategy.buildSystemPrompt(base)).toContain("custom plan here");
});

it("contains the technique catalogue", () => {
const prompt = strategy.buildSystemPrompt(base);
expect(prompt).toContain("TECHNIQUE CATALOGUE");
expect(prompt).toContain("HYPOTHETICAL FRAMING");
expect(prompt).toContain("PERSONA MODIFICATION");
});

it("contains current stage name in uppercase", () => {
expect(strategy.buildSystemPrompt(base)).toContain("EARLY");
});

it("contains turn info", () => {
expect(strategy.buildSystemPrompt({ ...base, currentTurn: 15 })).toContain("15 of 30");
});

it("contains scenario description", () => {
expect(strategy.buildSystemPrompt(base)).toContain("A bank support agent");
});

it("early and late prompts differ in stage hint", () => {
const early = strategy.buildSystemPrompt({ ...base, currentTurn: 1 });
const late = strategy.buildSystemPrompt({ ...base, currentTurn: 25 });
expect(early).toContain("EARLY");
expect(late).toContain("LATE");
expect(early).not.toBe(late);
});
});


});

// ---------------------------------------------------------------------------
// redTeamGoat factory
// ---------------------------------------------------------------------------

describe("redTeamGoat", () => {
it("creates a RedTeamAgent instance", () => {
const agent = redTeamGoat({ target: "test" });
expect(agent).toBeDefined();
});

it("defaults totalTurns to 30", () => {
const agent = redTeamGoat({ target: "test" }) as any;
expect(agent.totalTurns).toBe(30);
});

it("allows overriding totalTurns", () => {
const agent = redTeamGoat({ target: "test", totalTurns: 50 }) as any;
expect(agent.totalTurns).toBe(50);
});

it("uses GOAT_METAPROMPT_TEMPLATE", () => {
const agent = redTeamGoat({ target: "test" }) as any;
expect(agent.metapromptTemplate).toBe(GOAT_METAPROMPT_TEMPLATE);
});

it("allows overriding metapromptTemplate", () => {
const custom = "custom template {target} {description} {totalTurns}";
const agent = redTeamGoat({ target: "test", metapromptTemplate: custom }) as any;
expect(agent.metapromptTemplate).toBe(custom);
});

it("uses GoatStrategy", () => {
const agent = redTeamGoat({ target: "test" }) as any;
expect(agent.strategy).toBeInstanceOf(GoatStrategy);
});

it("uses different strategy than redTeamCrescendo", () => {
const goat = redTeamGoat({ target: "test" }) as any;
const crescendo = redTeamCrescendo({ target: "test" }) as any;
expect(goat.strategy.constructor).not.toBe(crescendo.strategy.constructor);
});
});

describe("injection probability config", () => {
it("defaults to 0.0", () => {
const agent = redTeamCrescendo({ target: "test", attackPlan: "plan" });
Expand Down
11 changes: 1 addition & 10 deletions javascript/src/agents/red-team/crescendo-strategy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,7 @@ const PHASES: Phase[] = [
];

export class CrescendoStrategy implements RedTeamStrategy {
phaseEnds(totalTurns: number): [number, number, number] {
const t = totalTurns;
return [
Math.max(1, Math.floor(0.20 * t)),
Math.max(1, Math.floor(0.45 * t)),
Math.max(1, Math.floor(0.75 * t)),
];
}

private getPhase(
getPhase(
currentTurn: number,
totalTurns: number
): { name: string; instructions: string } {
Expand Down
Loading
Loading