forked from Mercidaiha/IRT-Router
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_mirt.py
More file actions
186 lines (158 loc) · 6.33 KB
/
Copy pathtrain_mirt.py
File metadata and controls
186 lines (158 loc) · 6.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
# train_mirt.py — train a MIRT router model from precomputed LLM and query embeddings.
import argparse
import logging
import pandas as pd
import numpy as np
from numpy._typing import NDArray
from typing import Any
from torch.utils.data import TensorDataset, DataLoader
from router import MIRT
import torch
import pickle
# Name of the embedding backend; train_mirt maps it to the model's input widths.
emb_name = "bert"

# Emit INFO-level progress messages: set the level on the root logger, then
# obtain this module's own logger for the script to log through.
logging.root.setLevel(logging.INFO)
logger = logging.getLogger(__name__)
def _load_indexed_embeddings(path: str) -> dict[int, NDArray[Any]]:
    """Load a pickled sequence of records into ``{record["index"]: np.array(record["embedding"])}``.

    If an ``"index"`` value repeats, the last record with that index wins.

    SECURITY: ``pickle.load`` can execute arbitrary code while unpickling; only
    pass files you produced yourself or otherwise trust.
    """
    with open(path, "rb") as f:
        records = pickle.load(f)
    return {record["index"]: np.array(record["embedding"]) for record in records}


def load_embeddings(
    llm_embeddings_file: str,
    query_embeddings_file: str,
    llm_profiles_file: str,
    queries_file: str,
) -> tuple[
    dict[int, NDArray[Any]], dict[int, NDArray[Any]], dict[str, int], dict[str, int]
]:
    """Load LLM/query embeddings plus the name -> integer-index maps used to look them up.

    Args:
        llm_embeddings_file: Pickle of records, each subscriptable by ``"index"``
            and ``"embedding"``; keyed into the result by ``"index"``.
        query_embeddings_file: Pickle of records with the same two fields.
        llm_profiles_file: CSV with a ``name`` column (used as the lookup key)
            and an ``index`` column (the embedding key).
        queries_file: CSV with a ``question`` column (used as the lookup key)
            and an ``index`` column (the embedding key).

    Returns:
        ``(llm_embeddings, query_embeddings, llm_id_map, query_id_map)`` where the
        id maps translate an LLM name / question text to the integer key of its
        embedding.
    """
    logger.info("Loading llm embeddings")
    llm_embeddings = _load_indexed_embeddings(llm_embeddings_file)

    logger.info("Loading query embeddings")
    query_embeddings = _load_indexed_embeddings(query_embeddings_file)

    logger.info("Loading llm id map")
    # becomes something like: {'alt-anthropic/claude-sonnet-4-5-20250929': 1, 'alt-google/gemini-2.5-flash': 2, 'alt-google/gemini-2.5-pro': 3, ...}
    # Select the "index" column before converting so other CSV columns are not
    # needlessly turned into dicts; the resulting mapping is the same.
    llm_id_map = pd.read_csv(llm_profiles_file, index_col="name")["index"].to_dict()

    logger.info("Loading query id map")
    # becomes something like: {'What is the capital of France?': 1, 'What is the capital of Germany?': 2, 'What is the capital of Italy?': 3, ...}
    query_id_map = pd.read_csv(queries_file, index_col="question")["index"].to_dict()

    return llm_embeddings, query_embeddings, llm_id_map, query_id_map
def map_ids_to_vectors(
    data: pd.DataFrame,
    llm_embeddings: dict[int, NDArray[Any]],
    query_embeddings: dict[int, NDArray[Any]],
    llm_id_map: dict[str, int],
    query_id_map: dict[str, int],
) -> tuple[NDArray[Any], NDArray[Any]]:
    """Look up the LLM and query embedding for every row of ``data``.

    Each row's ``llm`` value is translated to an integer id via ``llm_id_map``
    and then to a vector via ``llm_embeddings``; likewise ``question`` via
    ``query_id_map`` and ``query_embeddings``.

    Args:
        data: Frame with at least ``llm`` and ``question`` columns.
        llm_embeddings: Integer id -> LLM embedding.
        query_embeddings: Integer id -> query embedding.
        llm_id_map: LLM name -> integer id.
        query_id_map: Question text -> integer id.

    Returns:
        ``(llm_vectors, query_vectors)`` built with ``np.array`` from the per-row
        vectors, in row order. For an empty ``data`` both have shape ``(0,)``.

    Raises:
        KeyError: if an LLM name or question is missing from its id map, or an
            id is missing from its embedding dict.
    """
    # Read the two needed columns directly instead of DataFrame.iterrows, which
    # builds a Series object per row and is slow on large training sets.
    llm_vectors = [llm_embeddings[llm_id_map[name]] for name in data["llm"]]
    query_vectors = [
        query_embeddings[query_id_map[question]] for question in data["question"]
    ]
    return np.array(llm_vectors), np.array(query_vectors)
# (query_dim, llm_dim) for each supported embedding backend named by ``emb_name``.
_EMBEDDING_DIMS: dict[str, tuple[int, int]] = {
    "open": (1536, 1536),
    "zhipu": (512, 512),
    "bge": (1024, 1024),
    "bert": (768, 768),
}


def _make_loader(
    llm_vectors: NDArray[Any],
    query_vectors: NDArray[Any],
    performance: pd.Series,
    batch_size: int,
    shuffle: bool,
) -> DataLoader:
    """Wrap (llm, query, performance) data as float32 tensors in a DataLoader."""
    dataset = TensorDataset(
        torch.tensor(llm_vectors, dtype=torch.float32),
        torch.tensor(query_vectors, dtype=torch.float32),
        torch.tensor(performance.values, dtype=torch.float32),
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def train_mirt(
    train_data_file: str,
    test_data_file: str,
    query_embeddings_file: str,
    llm_embeddings_file: str,
    llm_profiles_file: str,
    queries_file: str,
    output_file: str,
    epochs: int,
    lr: float,
    device: str,
    knowledge_n: int,
    batch_size: int,
):
    """Train a MIRT model on (LLM, query) -> performance data and save it.

    Args:
        train_data_file: CSV with ``llm``, ``question`` and ``performance`` columns.
        test_data_file: CSV with the same columns, passed to training as the
            held-out set.
        query_embeddings_file, llm_embeddings_file, llm_profiles_file, queries_file:
            Forwarded to ``load_embeddings``.
        output_file: Path passed to the trained model's ``save``.
        epochs, lr, device: Forwarded to ``MIRT.train``.
        knowledge_n: Third constructor argument of ``MIRT.MIRT`` (presumably the
            number of latent dimensions — confirm against router/MIRT).
        batch_size: Batch size for both the train and test loaders.

    Raises:
        ValueError: if the module-level ``emb_name`` is not a key of
            ``_EMBEDDING_DIMS``.
    """
    # Resolve the model input widths first so an unsupported backend fails fast
    # with a clear message, before any data is loaded.
    try:
        query_dim, llm_dim = _EMBEDDING_DIMS[emb_name]
    except KeyError:
        raise ValueError(
            f"Unsupported emb_name {emb_name!r}; expected one of {sorted(_EMBEDDING_DIMS)}"
        ) from None

    logger.info("Loading train and test data")
    train_data = pd.read_csv(train_data_file)
    test_data = pd.read_csv(test_data_file)
    llm_embeddings, query_embeddings, llm_id_map, query_id_map = load_embeddings(
        llm_embeddings_file, query_embeddings_file, llm_profiles_file, queries_file
    )

    logger.info("Mapping ids to vectors")
    train_llm, train_query = map_ids_to_vectors(
        train_data, llm_embeddings, query_embeddings, llm_id_map, query_id_map
    )
    test_llm, test_query = map_ids_to_vectors(
        test_data, llm_embeddings, query_embeddings, llm_id_map, query_id_map
    )

    train_set = _make_loader(
        train_llm, train_query, train_data["performance"], batch_size, shuffle=True
    )
    # Test batches are not shuffled.
    test_set = _make_loader(
        test_llm, test_query, test_data["performance"], batch_size, shuffle=False
    )

    cdm = MIRT.MIRT(llm_dim, query_dim, knowledge_n)
    cdm.train(train_set, test_set, epoch=epochs, device=device, lr=lr)
    cdm.save(output_file)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train MIRT model.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python train_mirt.py --train-data-file train.csv --test-data-file test.csv --query-embeddings-file query_embeddings.pkl --llm-embeddings-file llm_embeddings.pkl --llm-profiles-file llm_profiles.csv --queries-file queries.csv --output-file mirt_bert.snapshot
""",
    )
    # Input/output paths have no usable default. Requiring them makes a missing
    # flag fail here with a usage message, instead of passing None into
    # pd.read_csv / open / cdm.save deep inside training.
    parser.add_argument("--train-data-file", required=True, help="Train data file")
    parser.add_argument("--test-data-file", required=True, help="Test data file")
    parser.add_argument("--query-embeddings-file", required=True, help="Query embeddings file")
    parser.add_argument("--llm-embeddings-file", required=True, help="LLM embeddings file")
    parser.add_argument("--llm-profiles-file", required=True, help="LLM profiles file")
    parser.add_argument("--queries-file", required=True, help="Queries file")
    parser.add_argument("--output-file", required=True, help="Output file")
    # training arguments
    parser.add_argument("--epochs", help="Epochs", type=int, default=9)
    parser.add_argument("--lr", help="Learning rate", type=float, default=0.001)
    parser.add_argument("--device", help="Device", type=str, default="cpu")
    parser.add_argument("--knowledge-n", help="Knowledge n", type=int, default=25)
    parser.add_argument("--batch-size", help="Batch size", type=int, default=512)
    args = parser.parse_args()

    train_mirt(
        args.train_data_file,
        args.test_data_file,
        args.query_embeddings_file,
        args.llm_embeddings_file,
        args.llm_profiles_file,
        args.queries_file,
        args.output_file,
        args.epochs,
        args.lr,
        args.device,
        args.knowledge_n,
        args.batch_size,
    )