Skip to content

Commit 62fd0eb

Browse files
committed
Recover W2 novelty lane under the proven adaptive-clip stack
Port the W18 training and quantization defaults onto the older W2 pass-conditioned modulation lane so the next probe tests our own round-22 mechanism under a compliant, already-validated artifact path instead of re-running another near-1586 reproduction. The harness sync keeps local monitoring reliable while preserving the worker-facing launcher contract.

Constraint: The next lane must preserve W2's pass-conditioned modulation story while staying under the 16 MB cap and using the fixed local evaluator path.
Rejected: Re-run W19 with more seeds | single-seed result already underperformed W18 and still leaned on thin novelty
Rejected: More W18-family quantization tuning | stronger score story but too close to open PR openai#1586 to solve the submission problem
Confidence: medium
Scope-risk: moderate
Reversibility: clean
Directive: Treat this branch as a W2-on-W18 hybrid; if the score improves, review novelty framing against both openai#1518 and openai#1586 before escalating to 3 seeds
Tested: python3 -m py_compile train_gpt.py evaluate.py auto_resume_watch.py; python3 evaluate.py --list
Not-tested: Live GPU eval on this hybrid lane
Related: c0c2d68
Related: 7d435d2
1 parent a767e3d commit 62fd0eb

File tree

3 files changed

+33
-158
lines changed

3 files changed

+33
-158
lines changed

evaluate.py

Lines changed: 16 additions & 151 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,6 @@
2525

2626
WORKSPACE = os.path.expanduser("~/autoresearch/pgolf")
2727
os.makedirs(WORKSPACE, exist_ok=True)
28-
HEARTBEAT_DIR = os.path.join(WORKSPACE, "heartbeats")
29-
os.makedirs(HEARTBEAT_DIR, exist_ok=True)
3028

3129
DEFAULT_THRESHOLD = float(os.environ.get("AUTORESEARCH_THRESHOLD", "1.1164"))
3230
DEFAULT_TIMEOUT = 2700 # 45 min
@@ -62,7 +60,13 @@ def _load_env():
6260
# ---------------------------------------------------------------------------
6361

6462
def _run(cmd, check=False, timeout=30):
65-
r = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=timeout)
63+
try:
64+
r = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=timeout)
65+
except subprocess.TimeoutExpired as e:
66+
stdout = e.stdout if isinstance(e.stdout, str) else (e.stdout or b"").decode("utf-8", "replace")
67+
stderr = e.stderr if isinstance(e.stderr, str) else (e.stderr or b"").decode("utf-8", "replace")
68+
stderr = (stderr + f"\nTIMEOUT after {timeout}s").strip()
69+
r = subprocess.CompletedProcess(cmd, 124, stdout=stdout, stderr=stderr)
6670
if check and r.returncode != 0:
6771
raise RuntimeError(f"Command failed: {cmd}\n{r.stderr}")
6872
return r
@@ -123,16 +127,16 @@ def _make_job_command(commit_sha, branch=None):
123127
VOCAB=$(python3 << 'PYEOF'
124128
import re, sys
125129
f = open('train_gpt.py').read()
126-
m = re.search(r'VOCAB_SIZE.*?,\s*(\d+)', f)
130+
m = re.search(r'VOCAB_SIZE.*?,\\s*(\\d+)', f)
127131
if m: print(m.group(1)); sys.exit()
128132
try:
129133
import lzma, base64
130-
m2 = re.search(r"b85decode\([b]?['\"](.+?)['\"]\)", f, re.DOTALL)
134+
m2 = re.search(r"b85decode\\([b]?['\\\"](.+?)['\\\"]\\)", f, re.DOTALL)
131135
if m2:
132136
blob = m2.group(1)
133137
try: code = lzma.decompress(base64.b85decode(blob)).decode()
134138
except: code = lzma.decompress(base64.b85decode(blob), format=lzma.FORMAT_RAW, filters=[{"id": lzma.FILTER_LZMA2}]).decode()
135-
m3 = re.search(r'VOCAB_SIZE.*?,\s*(\d+)', code)
139+
m3 = re.search(r'VOCAB_SIZE.*?,\\s*(\\d+)', code)
136140
if m3: print(m3.group(1)); sys.exit()
137141
except Exception: pass
138142
print('1024')
@@ -150,7 +154,7 @@ def _make_job_command(commit_sha, branch=None):
150154
SCYLLA_DIR="./data/datasets/fineweb10B_scylla"
151155
if [ ! -f "$SCYLLA_DIR/.download_complete" ]; then
152156
echo "data_setup: downloading Scylla data from HuggingFace..."
153-
pip install -q huggingface_hub 2>/dev/null || true
157+
PIP_NO_CACHE_DIR=1 pip install -q --no-cache-dir huggingface_hub 2>/dev/null || true
154158
python3 -c "from huggingface_hub import snapshot_download; snapshot_download('anthonym21/fineweb10B-scylla', local_dir='$SCYLLA_DIR', repo_type='dataset')"
155159
touch "$SCYLLA_DIR/.download_complete"
156160
echo "data_setup: Scylla download complete"
@@ -173,21 +177,24 @@ def _make_job_command(commit_sha, branch=None):
173177
"""
174178

175179
clone_setup = f"""
180+
rm -rf /workspace/pgolf /root/.cache/pip ~/.cache/pip /tmp/pip-cache
181+
mkdir -p /workspace
176182
if [ -n "$PGOLF_GIT_TOKEN" ]; then
177183
CLONE_URL="https://x-access-token:${{PGOLF_GIT_TOKEN}}@github.com/{owner}/{repo}.git"
178184
else
179185
CLONE_URL="{REPO_URL}"
180186
fi
181-
GIT_TERMINAL_PROMPT=0 git clone --quiet "$CLONE_URL" /workspace/pgolf
187+
GIT_TERMINAL_PROMPT=0 git clone --quiet --filter=blob:none --no-tags "$CLONE_URL" /workspace/pgolf
182188
"""
183189

184190
return f"""set -e
185-
pip install -q sentencepiece huggingface-hub tiktoken zstandard brotli 2>/dev/null || true
191+
PIP_NO_CACHE_DIR=1 pip install -q --no-cache-dir sentencepiece huggingface-hub tiktoken zstandard brotli 2>/dev/null || true
186192
187193
{clone_setup}
188194
cd /workspace/pgolf
189195
git fetch origin {f'{branch}' if branch else '--all'}
190196
git checkout {commit_sha}
197+
rm -rf .git /root/.cache/pip ~/.cache/pip /tmp/pip-cache
191198
192199
export PYTHONUNBUFFERED=1
193200
@@ -333,77 +340,6 @@ def _log_path(job_id):
333340
return os.path.join(WORKSPACE, f"run_{job_id}.log")
334341

335342

336-
def _heartbeat_path(job_id=None):
337-
if job_id:
338-
return os.path.join(HEARTBEAT_DIR, f"{job_id}.json")
339-
return os.path.join(WORKSPACE, "heartbeat-latest.json")
340-
341-
342-
def _log_snapshot(log_file, max_tail_lines=8):
343-
if not log_file or not os.path.exists(log_file):
344-
return {
345-
"exists": False,
346-
"line_count": 0,
347-
"size_bytes": 0,
348-
"mtime": None,
349-
"tail": [],
350-
}
351-
st = os.stat(log_file)
352-
with open(log_file, "r", encoding="utf-8", errors="replace") as f:
353-
lines = f.read().splitlines()
354-
return {
355-
"exists": True,
356-
"line_count": len(lines),
357-
"size_bytes": st.st_size,
358-
"mtime": int(st.st_mtime),
359-
"tail": lines[-max_tail_lines:],
360-
}
361-
362-
363-
def _heartbeat_state_label(job_status, log_snapshot, now_ts=None):
364-
now_ts = int(now_ts or time.time())
365-
if job_status == "queueing":
366-
return "queued"
367-
if job_status in ("completed", "failed", "stopped", "timeout"):
368-
return job_status
369-
if not log_snapshot["exists"] or log_snapshot["line_count"] == 0:
370-
return "starting"
371-
mtime = log_snapshot.get("mtime")
372-
if mtime is None:
373-
return "running"
374-
if now_ts - mtime <= 90:
375-
return "streaming"
376-
return "quiet-running"
377-
378-
379-
def _write_heartbeat(kind, job_id=None, job_name=None, status=None, branch=None,
380-
commit=None, log_file=None, started_at=None, extra=None):
381-
now = int(time.time())
382-
snapshot = _log_snapshot(log_file)
383-
payload = {
384-
"kind": kind,
385-
"job_id": job_id,
386-
"job_name": job_name,
387-
"status": status,
388-
"state_label": _heartbeat_state_label(status, snapshot, now_ts=now),
389-
"branch": branch,
390-
"commit": commit,
391-
"log_file": log_file,
392-
"started_at": started_at,
393-
"updated_at": now,
394-
"elapsed_s": None if started_at is None else max(0, now - int(started_at)),
395-
"log": snapshot,
396-
}
397-
if extra:
398-
payload.update(extra)
399-
for path in {_heartbeat_path(), _heartbeat_path(job_id)}:
400-
try:
401-
with open(path, "w", encoding="utf-8") as f:
402-
json.dump(payload, f, indent=2, sort_keys=True)
403-
except Exception:
404-
pass
405-
406-
407343
def _has_final_results_content(content):
408344
"""Return True only when the final metric for the active eval mode is present."""
409345
if "results_json" in content:
@@ -674,8 +610,6 @@ def _signal_handler(signum, frame):
674610
commit = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True).stdout.strip()
675611
_log(f"branch={branch} commit={commit[:7]}")
676612

677-
started_at = int(time.time())
678-
679613
# 2. Create job
680614
try:
681615
job_name, job_id = _create_job(commit, node_group=args.node_group, branch=branch)
@@ -688,17 +622,6 @@ def _signal_handler(signum, frame):
688622
_main_state["log_file"] = log_file
689623
log_thread = threading.Thread(target=_stream_job_logs, args=(job_id, log_file), daemon=True)
690624
log_thread.start()
691-
_write_heartbeat(
692-
kind="eval",
693-
job_id=job_id,
694-
job_name=job_name,
695-
status="created",
696-
branch=branch,
697-
commit=commit[:7],
698-
log_file=log_file,
699-
started_at=started_at,
700-
extra={"threshold": args.threshold},
701-
)
702625

703626
# 4. Poll job status independently
704627
start = time.time()
@@ -707,17 +630,6 @@ def _signal_handler(signum, frame):
707630
status = _get_job_status(job_name, job_id)
708631
elapsed = int(time.time() - start)
709632
_log(f"[{elapsed}s] {job_name}: {status}")
710-
_write_heartbeat(
711-
kind="eval",
712-
job_id=job_id,
713-
job_name=job_name,
714-
status=status,
715-
branch=branch,
716-
commit=commit[:7],
717-
log_file=log_file,
718-
started_at=started_at,
719-
extra={"threshold": args.threshold},
720-
)
721633

722634
if status in ("completed", "failed", "stopped"):
723635
log_thread.join(timeout=15)
@@ -736,17 +648,6 @@ def _signal_handler(signum, frame):
736648
_log(f"Timeout after {args.timeout}s, stopping job")
737649
_stop_job_safe(job_id)
738650
log_thread.join(timeout=5)
739-
_write_heartbeat(
740-
kind="eval",
741-
job_id=job_id,
742-
job_name=job_name,
743-
status="timeout",
744-
branch=branch,
745-
commit=commit[:7],
746-
log_file=log_file,
747-
started_at=started_at,
748-
extra={"threshold": args.threshold},
749-
)
750651
_output(False, error=f"job timeout after {args.timeout}s")
751652

752653
# 5. Parse results from log
@@ -1074,17 +975,6 @@ def preflight(node_group=None, commit=None):
1074975
job_name, job_id = _create_job(commit, ng, branch=branch)
1075976
log_file = _log_path(job_id)
1076977
_log(f"Job: {job_name} ({job_id})")
1077-
started_at = int(time.time())
1078-
_write_heartbeat(
1079-
kind="preflight",
1080-
job_id=job_id,
1081-
job_name=job_name,
1082-
status="created",
1083-
branch=branch,
1084-
commit=commit[:7],
1085-
log_file=log_file,
1086-
started_at=started_at,
1087-
)
1088978

1089979
# Stream + poll
1090980
stream_thread = threading.Thread(target=_stream_job_logs, args=(job_id, log_file), daemon=True)
@@ -1097,16 +987,6 @@ def preflight(node_group=None, commit=None):
1097987
status = _get_job_status(job_name, job_id)
1098988
elapsed = int(time.time() - start)
1099989
_log(f"[preflight] [{elapsed}s] {status}")
1100-
_write_heartbeat(
1101-
kind="preflight",
1102-
job_id=job_id,
1103-
job_name=job_name,
1104-
status=status,
1105-
branch=branch,
1106-
commit=commit[:7],
1107-
log_file=log_file,
1108-
started_at=started_at,
1109-
)
1110990
if status in ("completed", "failed", "stopped"):
1111991
break
1112992

@@ -1152,21 +1032,6 @@ def preflight(node_group=None, commit=None):
11521032
_log(f" OVERALL: {'PASS' if all_pass else 'FAIL'}")
11531033
if not all_pass:
11541034
_log(f" Failed checks: {[k for k,v in checks.items() if not v]}")
1155-
_write_heartbeat(
1156-
kind="preflight",
1157-
job_id=job_id,
1158-
job_name=job_name,
1159-
status="completed" if all_pass else status,
1160-
branch=branch,
1161-
commit=commit[:7],
1162-
log_file=log_file,
1163-
started_at=started_at,
1164-
extra={
1165-
"preflight_pass": all_pass,
1166-
"preflight_checks": checks,
1167-
"details": results,
1168-
},
1169-
)
11701035
return all_pass
11711036

11721037

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,3 +9,4 @@ datasets
99
tiktoken
1010
sentencepiece
1111
brotli
12+
flash-attn-3

train_gpt.py

Lines changed: 16 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ class Hyperparameters:
1818
seed = int(os.environ.get("SEED", 1337))
1919
run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
2020
iterations = int(os.environ.get("ITERATIONS", 20000))
21-
warmdown_frac = float(os.environ.get("WARMDOWN_FRAC", 0.667))
21+
warmdown_frac = float(os.environ.get("WARMDOWN_FRAC", 0.75))
2222
warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
2323
train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786432))
2424
train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048))
@@ -55,7 +55,7 @@ class Hyperparameters:
5555
head_lr = float(os.environ.get("HEAD_LR", 0.008))
5656
tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03))
5757
tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
58-
matrix_lr = float(os.environ.get("MATRIX_LR", 0.022))
58+
matrix_lr = float(os.environ.get("MATRIX_LR", 0.026))
5959
scalar_lr = float(os.environ.get("SCALAR_LR", 0.02))
6060
muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.97))
6161
muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
@@ -72,8 +72,8 @@ class Hyperparameters:
7272
muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95))
7373
adam_wd = float(os.environ.get("ADAM_WD", 0.02))
7474
muon_wd = float(os.environ.get("MUON_WD", 0.095))
75-
embed_wd = float(os.environ.get("EMBED_WD", 0.095))
76-
ema_decay = float(os.environ.get("EMA_DECAY", 0.997))
75+
embed_wd = float(os.environ.get("EMBED_WD", 0.085))
76+
ema_decay = float(os.environ.get("EMA_DECAY", 0.9965))
7777
ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1")))
7878
ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 96))
7979
ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.0001))
@@ -98,9 +98,11 @@ class Hyperparameters:
9898
gptq_calibration_batches = int(os.environ.get("GPTQ_CALIBRATION_BATCHES", 64))
9999
gptq_reserve_seconds = float(os.environ.get("GPTQ_RESERVE_SECONDS", 12.0))
100100
matrix_bits = int(os.environ.get("MATRIX_BITS", 6))
101-
embed_bits = int(os.environ.get("EMBED_BITS", 8))
101+
embed_bits = int(os.environ.get("EMBED_BITS", 7))
102102
matrix_clip_sigmas = float(os.environ.get("MATRIX_CLIP_SIGMAS", 12.85))
103-
embed_clip_sigmas = float(os.environ.get("EMBED_CLIP_SIGMAS", 2e1))
103+
embed_clip_sigmas = float(os.environ.get("EMBED_CLIP_SIGMAS", 15.0))
104+
mlp_clip_sigmas = float(os.environ.get("MLP_CLIP_SIGMAS", 12.0))
105+
attn_clip_sigmas = float(os.environ.get("ATTN_CLIP_SIGMAS", 13.0))
104106
distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
105107
rank = int(os.environ.get("RANK", "0"))
106108
world_size = int(os.environ.get("WORLD_SIZE", "1"))
@@ -1674,7 +1676,14 @@ def gptq_mixed_quantize(state_dict, hessians, h):
16741676
result[name] = t.to(torch.float16) if t.is_floating_point() else t
16751677
meta[name] = "passthrough (float16)"
16761678
continue
1677-
cs = h.embed_clip_sigmas if "tok_emb" in name else h.matrix_clip_sigmas
1679+
if "tok_emb" in name:
1680+
cs = h.embed_clip_sigmas
1681+
elif ".mlp." in name:
1682+
cs = h.mlp_clip_sigmas
1683+
elif ".attn." in name:
1684+
cs = h.attn_clip_sigmas
1685+
else:
1686+
cs = h.matrix_clip_sigmas
16781687
bits = h.embed_bits if "tok_emb" in name else h.matrix_bits
16791688
q, s = gptq_quantize_weight(
16801689
t, hessians[name], clip_sigmas=cs, clip_range=2 ** (bits - 1) - 1

0 commit comments

Comments
 (0)