transformer-recommenders

Transformer-based Recommender Models in PyTorch for MovieLens

Overview

This repository provides modular implementations of recommender systems using transformer architectures, matrix factorization, and sequential models. It is designed for research and experimentation on MovieLens data, with scalable data access and experiment tracking.

Architecture & Components

  • Core package: xfmr_rec/
    • data.py: Data loading and preprocessing (MovieLens, LanceDB)
    • models.py, mf/, seq/, seq_embedded/: Model architectures (MF, sequential, transformer)
    • losses.py: Custom loss functions (BPR, CCL, SSM, etc.; see the sketch after this list)
    • metrics.py: Evaluation metrics
    • trainer.py: Training loop and experiment management (PyTorch Lightning)
    • service.py, deploy.py: Model serving and deployment utilities
  • Data:
    • data/: Raw and processed MovieLens datasets (Parquet format)
    • lance_db/: LanceDB format for fast retrieval
  • Experiment Logs:
    • lightning_logs/, mlruns/: Model checkpoints and experiment tracking (MLflow)
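
To make the loss interface concrete, here is a minimal sketch of a BPR (Bayesian Personalized Ranking) loss in plain PyTorch. It illustrates the idea only; the actual signatures and implementations in xfmr_rec/losses.py may differ.

import torch
import torch.nn.functional as F

def bpr_loss(pos_scores: torch.Tensor, neg_scores: torch.Tensor) -> torch.Tensor:
    # BPR maximizes the log-sigmoid margin between positive and negative item scores
    return -F.logsigmoid(pos_scores - neg_scores).mean()

# example: scores for 8 observed items and 8 sampled negatives
pos, neg = torch.randn(8), torch.randn(8)
print(bpr_loss(pos, neg))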

Installation

Requirements

  • Python 3.12+ (the project is developed and tested on 3.12)
  • The repository uses uv to manage virtual environments and tasks. See pyproject.toml for pinned dependencies.

Install dependencies with uv (recommended):

# set up the environment and install pinned deps
uv sync

Usage

Data preparation

This repo ships helper scripts to download and convert MovieLens datasets into Parquet and LanceDB formats.

Example: prepare MovieLens 1M (ml-1m) and write Parquet files into data/:

# fetch, extract and convert to parquet
uv run data

If you already have the original files (for example ml-1m.zip), place them under data/ and uv run data will pick them up. Otherwise the script will download and extract the dataset.
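
Once the conversion finishes, the Parquet outputs can be inspected directly. A minimal sketch with pandas, assuming a ratings file exists under data/ (the actual file names depend on the conversion script):

import pandas as pd

# hypothetical path; check data/ for the actual Parquet files the script writes
ratings = pd.read_parquet("data/ml-1m/ratings.parquet")
print(ratings.head())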

Training

Training is implemented with PyTorch Lightning. The repository exposes several task entrypoints.

Common training commands:

# Train a sequential transformer model for 16 epochs
uv run seq_train fit --trainer.max_epochs 16

# Train a matrix factorization model
uv run mf_train fit --trainer.max_epochs 10

Check pyproject.toml entrypoints for available tasks and the xfmr_rec/ modules for model and trainer configuration.
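
The fit subcommand and nested flags such as --trainer.max_epochs follow the PyTorch Lightning CLI convention. Below is a minimal, self-contained sketch of how such an entrypoint is typically wired; the toy model and datamodule are illustrative stand-ins, not the classes defined in xfmr_rec.

import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl
from lightning.pytorch.cli import LightningCLI

class ToyRecModel(pl.LightningModule):
    def __init__(self, embed_dim: int = 32):
        super().__init__()
        self.encoder = torch.nn.Linear(16, embed_dim)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.encoder(x).pow(2).mean()  # dummy loss for illustration

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

class ToyDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 8):
        super().__init__()
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 16)), batch_size=self.batch_size)

def main() -> None:
    # LightningCLI parses subcommands (fit, validate, ...) and nested flags such as
    # --trainer.max_epochs from the command line
    LightningCLI(ToyRecModel, ToyDataModule)

if __name__ == "__main__":
    main()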

Deployment and serving

The repository contains lightweight deployment utilities to run a retrieval service from a Lightning checkpoint.

# Serve a sequential model checkpoint on localhost
uv run python -m xfmr_rec.seq.deploy --ckpt_path <path/to/checkpoint.ckpt>

See xfmr_rec/service.py and xfmr_rec/deploy.py for convenience functions that load a checkpoint and expose a simple predict/retrieve API. The code uses LanceDB or Parquet data for fast lookups when available.
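
For the retrieval path, the following sketch shows the kind of LanceDB lookup such a service can perform. The table name and schema here are illustrative, not the layout this repository actually writes to lance_db/.

import lancedb
import numpy as np

# connect to a directory-backed database and create a demo table of item vectors
db = lancedb.connect("lance_db_demo")
items = db.create_table(
    "items",
    data=[{"id": i, "vector": np.random.rand(32).tolist()} for i in range(100)],
    mode="overwrite",
)

# in practice the query vector would come from the trained model's user/session embedding
query = np.random.rand(32).tolist()
top_k = items.search(query).limit(10).to_pandas()
print(top_k[["id", "_distance"]])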

Project conventions

  • Models are organized by type in subfolders (mf/, seq/, seq_embedded/) for extensibility.
  • Custom loss functions live in xfmr_rec/losses.py and are referenced by trainer hooks.
  • Experiment tracking is handled by PyTorch Lightning and MLflow; checkpoints and logs are stored in lightning_logs/ and mlruns/.
  • Data access is optimized using Parquet and (optionally) LanceDB for retrieval workloads.

Entrypoints

Task entrypoints are defined as project scripts in pyproject.toml and invoked with uv run. Typical entrypoints include:

  • data: datasets download and conversion utilities
  • mf_train, mf_deploy, mf_tune: matrix factorization training / deploy / tuning workflows
  • seq_train, seq_deploy, seq_tune: sequential / transformer training / deploy / tuning workflows
  • seq_embedded_train, seq_embedded_deploy: transformer (embedded) sequential workflows

Run uv run (without args) to list available tasks, or inspect pyproject.toml for concrete command mappings.

Development notes & troubleshooting

  • If you see dependency or Python version errors, confirm you are using Python 3.12 and run uv sync to recreate the virtual environment.
  • If training fails with out-of-memory errors, reduce the batch size (a data/model argument rather than a Trainer flag) or enable gradient accumulation via trainer.accumulate_grad_batches (see the sketch after this list).
  • Use the Lightning logs folder (lightning_logs/) to inspect checkpoints and TensorBoard summaries.
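
Gradient accumulation trades update frequency for memory: several small batches are accumulated before each optimizer step, approximating one large batch at a lower peak memory cost. In code, the equivalent Lightning Trainer configuration looks like this (the values are examples):

import lightning.pytorch as pl

# effective batch size = dataloader batch size x accumulate_grad_batches,
# at the peak memory cost of the smaller per-step batch
trainer = pl.Trainer(max_epochs=16, accumulate_grad_batches=4)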
