Skip to content

Commit 31af432

Browse files
chore: move maia2 onnx models + add stateful maia hook
1 parent 26a2e40 commit 31af432

File tree

9 files changed

+108
-49
lines changed

9 files changed

+108
-49
lines changed

src/hooks/useAnalysisController/useAnalysisController.ts

+4-5
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ export const useAnalysisController = (
4444
})
4545
}
4646

47-
const maia = useMaiaEngine()
47+
const { maia, status } = useMaiaEngine()
4848
const engine = useStockfishEngine(parseStockfishEvaluation)
49-
const [currentMove, setCurrentMove] = useState<null | [string, string]>(null)
49+
const [currentMove, setCurrentMove] = useState<[string, string] | null>()
5050
const [stockfishEvaluations, setStockfishEvaluations] = useState<
5151
StockfishEvaluation[]
5252
>([])
@@ -59,8 +59,7 @@ export const useAnalysisController = (
5959
const board = new Chess(game.moves[controller.currentIndex].board)
6060

6161
;(async () => {
62-
if (maia?.status !== 'ready' || maiaEvaluations[controller.currentIndex])
63-
return
62+
if (status !== 'ready' || maiaEvaluations[controller.currentIndex]) return
6463

6564
const { result } = await maia.batchEvaluate(
6665
Array(9).fill(board.fen()),
@@ -87,7 +86,7 @@ export const useAnalysisController = (
8786
return newEvaluations
8887
})
8988
})()
90-
}, [controller.currentIndex, game.type, maia?.status])
89+
}, [controller.currentIndex, game.type, status])
9190

9291
useEffect(() => {
9392
if (game.type === 'tournament') return

src/utils/maia2/model.ts src/hooks/useMaiaEngine/model.ts

+65-21
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,74 @@
1+
import { MaiaStatus } from 'src/types'
12
import { InferenceSession, Tensor } from 'onnxruntime-web'
23

34
import { mirrorMove, preprocess, allPossibleMovesReversed } from './utils'
45

6+
interface MaiaOptions {
7+
model: string
8+
type: 'rapid' | 'blitz'
9+
setStatus: (status: MaiaStatus) => void
10+
setProgress: (progress: number) => void
11+
setError: (error: string) => void
12+
}
13+
514
class Maia {
6-
public model!: InferenceSession
7-
public type: 'rapid' | 'blitz'
8-
public status: 'loading' | 'no-cache' | 'downloading' | 'ready'
9-
10-
constructor(options: { model: string; type: 'rapid' | 'blitz' }) {
11-
this.status = 'loading'
12-
this.type = options.type ?? 'rapid'
13-
;(async () => {
14-
try {
15-
console.log('Getting cached')
16-
const buffer = await this.getCachedModel(options.model, options.type)
17-
await this.initializeModel(buffer)
18-
} catch (e) {
19-
console.log('Missing cache')
20-
this.status = 'no-cache'
21-
}
22-
})()
15+
private model!: InferenceSession
16+
private type: 'rapid' | 'blitz'
17+
private modelUrl: string
18+
private options: MaiaOptions
19+
20+
constructor(options: MaiaOptions) {
21+
this.type = options.type
22+
this.modelUrl = options.model
23+
this.options = options
24+
25+
this.initialize()
2326
}
2427

25-
public getStatus() {
26-
return this.status
28+
private async initialize() {
29+
try {
30+
const buffer = await this.getCachedModel(this.modelUrl, this.type)
31+
await this.initializeModel(buffer)
32+
this.options.setStatus('ready')
33+
} catch (e) {
34+
this.options.setStatus('no-cache')
35+
}
36+
}
37+
38+
public async downloadModel() {
39+
const response = await fetch(this.modelUrl)
40+
if (!response.ok) throw new Error('Failed to fetch model')
41+
42+
const reader = response.body?.getReader()
43+
const contentLength = +(response.headers.get('Content-Length') ?? 0)
44+
45+
if (!reader) throw new Error('No response body')
46+
47+
const chunks: Uint8Array[] = []
48+
let receivedLength = 0
49+
50+
while (true) {
51+
const { done, value } = await reader.read()
52+
if (done) break
53+
54+
chunks.push(value)
55+
receivedLength += value.length
56+
57+
this.options.setProgress((receivedLength / contentLength) * 100)
58+
}
59+
60+
const buffer = new Uint8Array(receivedLength)
61+
let position = 0
62+
for (const chunk of chunks) {
63+
buffer.set(chunk, position)
64+
position += chunk.length
65+
}
66+
67+
const cache = await caches.open(`MAIA2-${this.type.toUpperCase()}-MODEL`)
68+
await cache.put(this.modelUrl, new Response(buffer.buffer))
69+
70+
await this.initializeModel(buffer.buffer)
71+
this.options.setStatus('ready')
2772
}
2873

2974
public async getCachedModel(
@@ -52,8 +97,7 @@ class Maia {
5297

5398
public async initializeModel(buffer: ArrayBuffer) {
5499
this.model = await InferenceSession.create(buffer)
55-
this.status = 'ready'
56-
console.log('initialized')
100+
this.options.setStatus('ready')
57101
}
58102

59103
/**
+31-11
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,38 @@
1-
import { useState, useMemo, useEffect } from 'react'
2-
3-
import Maia from 'src/utils/maia2'
1+
import Maia from './model'
2+
import { MaiaStatus } from 'src/types'
3+
import { useState, useMemo } from 'react'
44

55
export const useMaiaEngine = () => {
6-
const [maia, setMaia] = useState<Maia>()
6+
const [status, setStatus] = useState<MaiaStatus>('loading')
7+
const [progress, setProgress] = useState(0)
8+
const [error, setError] = useState<string | null>(null)
79

8-
useEffect(() => {
9-
setMaia(new Maia({ model: '/maia2/maia_rapid.onnx', type: 'rapid' }))
10+
const maia = useMemo(() => {
11+
const model = new Maia({
12+
model: '/maia2/maia_rapid.onnx',
13+
type: 'rapid',
14+
setStatus: setStatus,
15+
setProgress: setProgress,
16+
setError: setError,
17+
})
18+
return model
1019
}, [])
1120

12-
// const maia = useMemo(() => {
13-
// const model = new Maia({ model: '/maia2/maia_rapid.onnx', type: 'rapid' })
14-
// return model
15-
// }, [])
21+
const downloadModel = async () => {
22+
try {
23+
setStatus('downloading')
24+
await maia.downloadModel()
25+
} catch (err) {
26+
setError(err instanceof Error ? err.message : 'Failed to download model')
27+
setStatus('error')
28+
}
29+
}
1630

17-
return maia
31+
return {
32+
maia,
33+
status,
34+
progress,
35+
error,
36+
downloadModel,
37+
}
1838
}
File renamed without changes.

src/pages/analysis/[...id].tsx

+1-9
Original file line numberDiff line numberDiff line change
@@ -577,15 +577,7 @@ const Analysis: React.FC<Props> = ({
577577
content="Collection of chess training and analysis tools centered around Maia."
578578
/>
579579
</Head>
580-
{maia?.status !== 'ready' ? (
581-
<>
582-
<div className="absolute left-0 top-0 z-50 flex h-screen w-screen flex-col bg-black">
583-
<p className="text-white">{maia?.status}</p>
584-
</div>
585-
</>
586-
) : (
587-
<></>
588-
)}
580+
589581
<GameControllerContext.Provider value={{ ...controller }}>
590582
{analyzedGame && (isMobile ? mobileLayout : desktopLayout)}
591583
</GameControllerContext.Provider>

src/types/analysis/index.ts

+7
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,10 @@ export interface StockfishEvaluation {
7272
cp_vec: { [key: string]: number }
7373
cp_relative_vec: { [key: string]: number }
7474
}
75+
76+
export type MaiaStatus =
77+
| 'loading'
78+
| 'no-cache'
79+
| 'downloading'
80+
| 'ready'
81+
| 'error'

src/utils/maia2/index.ts

-3
This file was deleted.

0 commit comments

Comments
 (0)