An end-to-end machine learning project that trains a model to read and translate handwritten math into LaTeX code.
This is a fine-tuned version of `microsoft/trocr-base-handwritten`, a transformer-based optical character recognition model, adapted to work with handwritten math images and structured math syntax. You can find it on Hugging Face as `tjoab/latex_finetuned`.
In this repo, you'll find:
- 🧱 Data + preprocessing pipeline from raw InkML files to model-ready image/label pairs
- 🧠 TrOCR fine-tuning using a custom PyTorch training loop and `DataLoader`s
- 💾 Use of gradient accumulation + mixed precision to train on limited hardware
- 📊 Model logging and checkpointing for segmented training sessions
- 🖥️ Lightweight demo to showcase model inference
- 🚀 (Coming soon) Docker containerization for cloud deployment on AWS SageMaker
Most OCR systems perform well on natural language, but they struggle with mathematical notation, especially when it's handwritten. LaTeXify aims to understand the structure of math.
Math expressions aren't linear like natural text; they're inherently 2D. You're not just translating symbols, you're interpreting spatial relationships: superscripts, fractions, nested square roots, integrals with bounds, and multi-level subscripts. This makes math recognition fundamentally different from typical OCR or sequence-to-sequence tasks.
I wanted to directly output LaTeX from rasterized handwriting: no intermediate character recognition, no symbol lookup, just end-to-end learning.
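For context, end-to-end inference with a TrOCR-style checkpoint looks roughly like the sketch below. This is a minimal illustration using the standard Hugging Face `transformers` API; it assumes the `tjoab/latex_finetuned` checkpoint ships a compatible processor config (otherwise the `microsoft/trocr-base-handwritten` processor can be used), and `example.png` is a hypothetical input image.

```python
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Load the processor (image preprocessing + tokenizer) and the fine-tuned model.
processor = TrOCRProcessor.from_pretrained("tjoab/latex_finetuned")
model = VisionEncoderDecoderModel.from_pretrained("tjoab/latex_finetuned")

# Rasterized handwriting in, LaTeX out: no intermediate symbol recognition.
image = Image.open("example.png").convert("RGB")  # hypothetical input image
pixel_values = processor(images=image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values, max_length=256)
latex = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(latex)
```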
Training a transformer model is pretty demanding, especially on commodity hardware. To make this process more accessible, I used a couple of tricks that let me train this model on an NVIDIA T4 with 16 GB of VRAM.
- By using mixed precision (`torch.cuda.amp`)
  - Reduced RAM consumption by using `float16` where possible
  - Look out for `autocast()` and `GradScaler()` calls inside of `train/train.py`:
    ```python
    with autocast():
        outputs = model(pixel_values=images, labels=labels)
        loss = outputs.loss / grad_accumulation_steps
    ```
- Small batch sizes are inherently noisy, and transformer models benefit more from larger batches
  - But increasing batch size could cause memory issues
- Introduce gradient accumulation
  - Enables a larger effective batch by accumulating gradients over several small batches, then updating model weights
  - This improves the quality of our gradient signal without increasing peak memory load per step
  - Essentially trading time for memory, because compute is cheap while memory is scarce (see the training-loop sketch below)
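Putting the two tricks together, a training step looks roughly like the sketch below. This is an illustrative loop rather than a copy of `train/train.py`; the `model`, `optimizer`, `dataloader`, and `grad_accumulation_steps` names are assumed to be defined elsewhere.

```python
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()        # scales the loss so float16 gradients don't underflow
grad_accumulation_steps = 8  # effective batch = batch_size * 8
optimizer.zero_grad()

for step, (images, labels) in enumerate(dataloader):
    images, labels = images.cuda(), labels.cuda()

    # Forward pass runs in mixed precision: float16 where safe, float32 elsewhere.
    with autocast():
        outputs = model(pixel_values=images, labels=labels)
        loss = outputs.loss / grad_accumulation_steps  # average over the accumulation window

    # Backward pass accumulates scaled gradients without an optimizer step.
    scaler.scale(loss).backward()

    # Only update weights once enough micro-batches have been accumulated.
    if (step + 1) % grad_accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```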
This project uses `pycairo` for rendering handwritten strokes. If you plan on using the `DataLoader` from `train/dataset.py`, you must install these system libraries prior to installing the Python dependencies:

```bash
sudo apt-get install -y libcairo2-dev libjpeg-dev libgif-dev
```

Otherwise, you can remove `pycairo` from `requirements.txt` and run:

```bash
pip install -r requirements.txt
```
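For a sense of what `pycairo` does here, the sketch below rasterizes a list of (x, y) stroke points onto a white canvas. It is only an illustration; the actual InkML parsing and normalization in `train/dataset.py` may differ.

```python
import cairo

def render_strokes(strokes, width=384, height=384, path="sample.png"):
    """Rasterize handwriting strokes (lists of (x, y) points) to a PNG."""
    surface = cairo.ImageSurface(cairo.FORMAT_RGB24, width, height)
    ctx = cairo.Context(surface)

    # White background.
    ctx.set_source_rgb(1, 1, 1)
    ctx.paint()

    # Black pen strokes.
    ctx.set_source_rgb(0, 0, 0)
    ctx.set_line_width(2)
    ctx.set_line_cap(cairo.LINE_CAP_ROUND)

    for stroke in strokes:
        x0, y0 = stroke[0]
        ctx.move_to(x0, y0)
        for x, y in stroke[1:]:
            ctx.line_to(x, y)
        ctx.stroke()

    surface.write_to_png(path)

# Example: a single diagonal stroke.
render_strokes([[(20, 20), (120, 160), (200, 200)]])
```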
Would you like to change your training parameters, choose your own model to fine-tune, or toggle model checkpoints/logs? No need to touch any of the Python logic; everything is driven from config files. Take a look in `train/train_config.yaml` or `evaluation/eval_config.yaml` and make your changes.
```yaml
model_name: microsoft/trocr-base-handwritten
data_dir: ./data/mathwriting-2024/train/
batch_size: 8
grad_accumulation: 8
learning_rate: 5e-5
warmup_steps: 1000
perform_logs: false
log_dir: ./train/logs/
```
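For reference, a config like this can be read with PyYAML in a few lines. This is just a sketch of the pattern, not necessarily how `train/train.py` loads it; note that PyYAML parses `5e-5` as a string, so the learning rate is cast explicitly.

```python
import yaml

with open("train/train_config.yaml") as f:
    config = yaml.safe_load(f)

model_name = config["model_name"]               # e.g. microsoft/trocr-base-handwritten
batch_size = config["batch_size"]               # micro-batch size per step
grad_accumulation = config["grad_accumulation"] # accumulation steps per optimizer update
learning_rate = float(config["learning_rate"])  # cast in case YAML kept it as a string
```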
I decided to evaluate performance using Character Error Rate (CER), which is defined below. It basically tells you what fraction of the characters in the target output were wrong: either missing, incorrect, or extra.
CER = (Substitutions + Insertions + Deletions) / Total Characters in Ground Truth
Math expressions are structurally sensitive. Shuffling even a single character can completely change the meaning:
- `x^2` vs. `x_2`
- `\frac{a}{b}` vs. `\frac{b}{a}`
In the past I've worked with BLEU, which is a sequence-level metric; however, I settled on CER because it penalizes small syntax errors more harshly.
Evaluation of `tjoab/latex_finetuned` yielded a CER of 14.9%.
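To make the CER definition concrete, here is a small, self-contained sketch that computes it with a character-level edit distance; the actual evaluation code in `evaluation/` may rely on a library instead.

```python
def char_error_rate(prediction: str, reference: str) -> float:
    """CER = (substitutions + insertions + deletions) / len(reference)."""
    # Standard dynamic-programming Levenshtein distance over characters.
    prev = list(range(len(reference) + 1))
    for i, p in enumerate(prediction, start=1):
        curr = [i]
        for j, r in enumerate(reference, start=1):
            cost = 0 if p == r else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1] / len(reference)

print(char_error_rate("x_2", "x^2"))  # 0.33..., one substitution out of three characters
```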
- 🤗 HuggingFace Transformers: for TrOCR and tokenizers
- 🔥 PyTorch: for training loops, data loading, and AMP
- 🖼️ Streamlit: model demo (👆 click the link)