-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels.py
More file actions
152 lines (115 loc) · 5.57 KB
/
models.py
File metadata and controls
152 lines (115 loc) · 5.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
Typed Pydantic models for the Customer Support Triage environment.
All Action, Observation, and Reward types live here.
"""
from __future__ import annotations
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
# ── Enumerations ──────────────────────────────────────────────────────────────
class Priority(str, Enum):
    """Priority level an agent assigns to a ticket (``TriageAction.priority``).

    Members subclass ``str``, so each compares equal to its lowercase value
    (e.g. ``Priority.HIGH == "high"``).

    NOTE: because of the ``str`` mixin, ``<``/``>`` compare the underlying
    strings alphabetically ("high" < "low" < "medium" < "urgent"), NOT by
    urgency. Rank priorities explicitly if an ordering is needed.
    """

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    URGENT = "urgent"
class Department(str, Enum):
    """Department a ticket can be routed to (``TriageAction.department``).

    Members subclass ``str`` and compare equal to their lowercase values.
    ``ESCALATION`` is a routing target in its own right; it is a separate
    field from ``TriageAction.needs_human``.
    """

    BILLING = "billing"
    TECHNICAL = "technical"
    SHIPPING = "shipping"
    RETURNS = "returns"
    GENERAL = "general"
    ESCALATION = "escalation"
class Sentiment(str, Enum):
    """Customer sentiment attached to a ticket.

    Unlike the ground-truth labels, sentiment is visible to the agent
    (``TicketObservation`` carries it). Members subclass ``str`` and compare
    equal to their lowercase values.
    """

    POSITIVE = "positive"
    NEUTRAL = "neutral"
    NEGATIVE = "negative"
    ANGRY = "angry"
# ── Ticket model ──────────────────────────────────────────────────────────────
class SupportTicket(BaseModel):
    """A support ticket as held by the environment, including grader-only data.

    The public fields are the same ones exposed to the agent via
    ``TicketObservation``; the underscore-prefixed attributes below carry
    the ground truth used for grading.
    """

    ticket_id: str
    subject: str
    body: str
    customer_name: str
    customer_tier: str  # "free" | "pro" | "enterprise" (not enforced by the type)
    created_at: str  # ISO timestamp string (kept as text, not parsed)
    sentiment: Sentiment
    tags: List[str] = Field(default_factory=list)  # fresh list per instance

    # Ground truth (hidden from agent in observation, used by grader).
    # NOTE(review): Pydantic does not treat underscore-prefixed names as
    # fields. In Pydantic v2 they become private attributes: a value passed
    # to the constructor (``SupportTicket(_true_priority=...)``) is not
    # stored under the default ``extra`` setting, and they are left out of
    # ``model_dump()``. They must be assigned after construction. In
    # Pydantic v1 (without ``underscore_attrs_are_private``) they are plain
    # class attributes, and assigning them on an instance is rejected.
    # Confirm the pinned Pydantic version and how the environment populates
    # these before relying on them.
    _true_priority: Optional[Priority] = None
    _true_department: Optional[Department] = None
    _true_response: Optional[str] = None
# ── Action ────────────────────────────────────────────────────────────────────
class TriageAction(BaseModel):
    """
    The agent's action for one ticket in the queue.

    Fields
    ------
    ticket_id   : which ticket this action applies to
    priority    : assigned priority (low/medium/high/urgent)
    department  : routed department
    response    : draft reply to the customer (1-3 sentences); required,
                  10-500 characters
    needs_human : whether to flag for human review; defaults to False
    reasoning   : brief chain-of-thought (not scored, aids debugging);
                  optional, at most 300 characters

    The length bounds are enforced by Pydantic validation. An out-of-range
    value fails validation for the whole action, so an over-long
    ``reasoning`` rejects the action even though reasoning is not scored.
    """

    ticket_id: str
    priority: Priority
    department: Department
    response: str = Field(..., min_length=10, max_length=500)
    needs_human: bool = False
    reasoning: str = Field(default="", max_length=300)
class BatchTriageAction(BaseModel):
    """Action covering the full queue (used by inference script).

    Holds a list of per-ticket ``TriageAction`` objects. The model itself
    does not check that ticket ids are unique or that every queued ticket is
    covered.
    """

    actions: List[TriageAction]
# ── Observation ───────────────────────────────────────────────────────────────
class TicketObservation(BaseModel):
    """What the agent sees for a single ticket (ground-truth fields stripped).

    Mirrors the public fields of ``SupportTicket``. Prefer
    :meth:`from_ticket` when converting a stored ticket, so that only
    agent-visible data is copied.
    """

    ticket_id: str
    subject: str
    body: str
    customer_name: str
    customer_tier: str
    created_at: str
    sentiment: Sentiment
    tags: List[str]

    @classmethod
    def from_ticket(cls, ticket: "SupportTicket") -> "TicketObservation":
        """Build the agent-visible view of *ticket*.

        Parameters
        ----------
        ticket : the full ticket, possibly carrying grader-only attributes.

        Returns
        -------
        A new ``TicketObservation`` holding only the public fields.

        Fields are copied one by one rather than by dumping the whole model,
        so grader-only data can never reach the observation, whatever
        Pydantic version or dump options are used.
        """
        return cls(
            ticket_id=ticket.ticket_id,
            subject=ticket.subject,
            body=ticket.body,
            customer_name=ticket.customer_name,
            customer_tier=ticket.customer_tier,
            created_at=ticket.created_at,
            sentiment=ticket.sentiment,
            # Copy so mutating the observation's tags cannot alter the ticket.
            tags=list(ticket.tags),
        )
class TriageObservation(BaseModel):
    """Full observation returned by reset() and step().

    Wrapped by both ``ResetResponse.observation`` and
    ``StepResult.observation``.
    """

    queue: List[TicketObservation]  # tickets still to process
    processed: int  # how many tickets done so far
    total_tickets: int  # tickets in the episode (presumably processed + len(queue); confirm)
    task_name: str  # task label, presumably "easy" | "medium" | "hard" as in ResetRequest
    step_number: int  # current step counter (whether it starts at 0 or 1 is set by the env)
    time_remaining: int  # pseudo-seconds budget remaining
# ── Reward breakdown ──────────────────────────────────────────────────────────
class RewardInfo(BaseModel):
    """Per-ticket reward breakdown produced by the grader.

    NOTE: the 0–1 ranges below are conventions only. No ``ge``/``le``
    constraints are declared, so out-of-range values are not rejected here.
    The weights behind ``total`` are defined by the grader, not this model.
    """

    priority_score: float  # 0–1 correct priority assignment
    routing_score: float  # 0–1 correct department routing
    response_quality: float  # 0–1 response relevance & completeness
    escalation_score: float  # 0–1 correct human-escalation decisions
    total: float  # weighted composite
    ticket_id: str  # ticket this breakdown refers to
    feedback: str  # human-readable grader feedback
# ── Step result ───────────────────────────────────────────────────────────────
class StepResult(BaseModel):
    """Result of one environment step: new observation, reward and status."""

    observation: TriageObservation
    reward: float  # presumably the step's RewardInfo.total; confirm against env
    done: bool  # True once the episode has ended
    # Free-form extras; default_factory gives each result its own dict.
    info: Dict[str, Any] = Field(default_factory=dict)
# ── State snapshot ────────────────────────────────────────────────────────────
class EnvState(BaseModel):
    """Snapshot of the current episode's bookkeeping state."""

    task_name: str
    step_number: int
    total_tickets: int
    processed: int
    pending: int  # presumably total_tickets - processed; confirm
    episode_reward: float  # accumulated reward (aggregation rule set by the env)
    reward_history: List[float]  # presumably one entry per step
    actions_taken: List[Dict[str, Any]]  # presumably serialized TriageAction dicts
    done: bool
# ── Reset request/response ────────────────────────────────────────────────────
class ResetRequest(BaseModel):
    """Request body for resetting the environment.

    Both fields are optional and default to ``None`` on this model. Any
    default values are resolved by the handler, not here.
    """

    # "easy" | "medium" | "hard". None lets the handler choose; documented
    # as "easy", but that default is applied outside this model.
    task: Optional[str] = None
    # None presumably means the handler picks a seed; the seed actually used
    # is reported back in ResetResponse.seed.
    seed: Optional[int] = None
class ResetResponse(BaseModel):
    """Response to a reset: the initial observation plus resolved settings."""

    observation: TriageObservation
    task: str  # task actually used (always set, unlike ResetRequest.task)
    seed: int  # seed actually used (always set, unlike ResetRequest.seed)