
Commit f84b06e

Merge remote-tracking branch 'origin/main'

2 parents: 88b999c + e8e6c9d

File tree: 11 files changed, +551 −53 lines


README.md

Lines changed: 10 additions & 7 deletions
````diff
@@ -48,7 +48,11 @@ python pretrain_classification.py -epochs 80 -steps 25 -batchsize 50 -priordump
 ```
 This should take less than 5 min on a modern NVIDIA GPU (around 10 minutes on a MacBook M4 Pro GPU and around 40 min on the M4 Pro CPU).
 
-#### Step by Step Explanation
+We also offer a pre-generated dataset containing 1.28M tables with 50 datapoints and 3 features each for regression [here](https://ml.informatik.uni-freiburg.de/research-artifacts/pfefferle/nanoTabPFN/50x3_1280k_regression.h5).
+
+You can pretrain on it using `python pretrain_regressor.py`.
+
+#### Step by Step Explanation (Classifier)
 
 First we import our Architecture, Prior interface and training loop, etc.
 ```python
@@ -58,6 +62,7 @@ from nanotabpfn.train import train
 from nanotabpfn.utils import get_default_device
 from nanotabpfn.interface import NanoTabPFNClassifier
 from torch.nn import CrossEntropyLoss
+from nanotabpfn.callbacks import ConsoleLoggerCallback
 ```
 then we instantiate our model and loss criterion:
 ```python
@@ -77,17 +82,15 @@ prior = PriorDumpDataLoader(filename='50x3_3_100k_classification.h5', num_steps=
 ```
 and finally train our model:
 ```python
-def epoch_callback(epoch, epoch_time, mean_loss, model):
-    classifier = NanoTabPFNClassifier(model, device)
-    # you can add your own eval code here that runs after every epoch
-    print(f'epoch {epoch:5d} | time {epoch_time:5.2f}s | mean loss {mean_loss:5.2f}', flush=True)
-
 trained_model, loss = train(
     model=model,
     prior=prior,
     criterion=criterion,
     epochs=80,
     device=device,
-    epoch_callback=epoch_callback
+    callbacks=[ConsoleLoggerCallback()]
 )
 ```
+
+### Creating your own datasets
+Check out the [tabularpriors](https://github.com/automl/tabularpriors/) repository to create your own data using publicly available priors.
````
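
The new `callbacks` parameter takes a list, so the loggers introduced in this commit can be combined. A minimal sketch (the `log_dir` value is a placeholder; `model`, `prior`, `criterion`, and `device` are the objects set up in the README snippet above):

```python
from nanotabpfn.callbacks import ConsoleLoggerCallback, TensorboardLoggerCallback

# log every epoch to stdout and to TensorBoard simultaneously
trained_model, loss = train(
    model=model,
    prior=prior,
    criterion=criterion,
    epochs=80,
    device=device,
    callbacks=[ConsoleLoggerCallback(), TensorboardLoggerCallback(log_dir='runs/nanotabpfn')],
)
```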

nanotabpfn/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1 +1 @@
-from nanotabpfn.interface import NanoTabPFNClassifier
+from nanotabpfn.interface import NanoTabPFNClassifier, NanoTabPFNRegressor
```
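
With `NanoTabPFNRegressor` now exported at the package root, it can be used through the same scikit-learn-style interface the classifier exposes. A minimal sketch on synthetic data (passing `model=None` falls back to the default weights, per the CLI help in evaluation.py below):

```python
import numpy as np

from nanotabpfn import NanoTabPFNRegressor

# toy regression problem: 3 features, linear target with noise
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=100)

reg = NanoTabPFNRegressor(model=None)  # None loads the default weights
reg.fit(X[:80], y[:80])
y_pred = reg.predict(X[80:])
```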

nanotabpfn/callbacks.py

Lines changed: 90 additions & 0 deletions
```python
from abc import ABC, abstractmethod


class Callback(ABC):
    """ Abstract base class for callbacks. """

    @abstractmethod
    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
        """
        Called at the end of each epoch.

        Args:
            epoch (int): The current epoch number.
            epoch_time (float): Time of the epoch in seconds.
            loss (float): Mean loss for the epoch.
            model: The model being trained.
            **kwargs: Additional arguments.
        """
        pass

    @abstractmethod
    def close(self):
        """
        Called to release any resources or perform cleanup.
        """
        pass


class BaseLoggerCallback(Callback):
    """ Abstract base class for logger callbacks. """
    pass


class ConsoleLoggerCallback(BaseLoggerCallback):
    """ Logger callback that prints epoch information to the console. """

    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
        print(f'Epoch {epoch:5d} | Time {epoch_time:5.2f}s | Mean Loss {loss:5.2f}', flush=True)

    def close(self):
        """ Nothing to clean up for print logger. """
        pass


class TensorboardLoggerCallback(BaseLoggerCallback):
    """ Logger callback that logs epoch information to TensorBoard. """

    def __init__(self, log_dir: str):
        from torch.utils.tensorboard import SummaryWriter
        self.writer = SummaryWriter(log_dir=log_dir)

    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
        self.writer.add_scalar('Loss/train', loss, epoch)
        self.writer.add_scalar('Time/epoch', epoch_time, epoch)

    def close(self):
        self.writer.close()


class WandbLoggerCallback(BaseLoggerCallback):
    """ Logger callback that logs epoch information to Weights & Biases. """

    def __init__(self, project: str, name: str = None, config: dict = None, log_dir: str = None):
        """
        Initializes a WandbLoggerCallback.

        Args:
            project (str): The name of the wandb project.
            name (str, optional): The name of the run. Defaults to None.
            config (dict, optional): Configuration dictionary for the run. Defaults to None.
            log_dir (str, optional): Directory to save wandb logs. Defaults to None.
        """
        try:
            import wandb
            self.wandb = wandb  # store wandb module to avoid import if not used
            wandb.init(
                project=project,
                name=name,
                config=config,
                dir=log_dir,
            )
        except ImportError as e:
            raise ImportError("wandb is not installed. Install it with: pip install wandb") from e

    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
        log_dict = {'epoch': epoch, 'loss': loss, 'epoch_time': epoch_time}
        self.wandb.log(log_dict)

    def close(self):
        self.wandb.finish()
```
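
Since `Callback` is an ABC with just two required methods, adding a custom hook is straightforward. A sketch of a hypothetical checkpointing callback (the filename pattern and `torch.save` usage are illustrative, not part of this commit; it assumes `model` is a torch module with a `state_dict()`):

```python
import torch

from nanotabpfn.callbacks import Callback


class CheckpointCallback(Callback):
    """ Hypothetical callback that saves model weights every few epochs. """

    def __init__(self, path_pattern: str = 'checkpoint_epoch{epoch}.pt', every: int = 10):
        self.path_pattern = path_pattern
        self.every = every

    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
        if epoch % self.every == 0:
            torch.save(model.state_dict(), self.path_pattern.format(epoch=epoch))

    def close(self):
        pass  # no resources to release
```

It can then be passed alongside the loggers, e.g. `train(..., callbacks=[ConsoleLoggerCallback(), CheckpointCallback()])`.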

nanotabpfn/evaluation.py

Lines changed: 163 additions & 0 deletions
```python
import argparse

import numpy as np
import openml
import torch
from openml.config import set_root_cache_directory
from openml.tasks import TaskType
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, r2_score
from sklearn.preprocessing import LabelEncoder

from nanotabpfn.interface import NanoTabPFNRegressor, NanoTabPFNClassifier

TOY_TASKS_REGRESSION = [
    362443,  # diabetes
]

TOY_TASKS_CLASSIFICATION = [
    59,    # iris
    2382,  # wine
    9946,  # breast_cancer
]


@torch.no_grad()
def get_openml_predictions(
    *,
    model: NanoTabPFNRegressor | NanoTabPFNClassifier,
    tasks: list[int] | str = "tabarena-v0.1",
    max_n_features=500,
    max_n_instances=10_000,
    classification: bool | None = None,
    cache_directory: str | None = None,
):
    """
    Evaluates a model on a set of OpenML tasks and returns predictions.

    Retrieves datasets from OpenML, applies preprocessing, and evaluates the given model on each task.
    Returns true targets, predicted labels, and predicted probabilities for each dataset.

    Args:
        model (NanoTabPFNRegressor | NanoTabPFNClassifier): A scikit-learn compatible model or classifier to be evaluated.
        tasks (list[int] | str, optional): A list of OpenML task IDs or the name of a benchmark suite.
        max_n_features (int, optional): Maximum number of features allowed for a task. Tasks exceeding this limit are skipped.
        max_n_instances (int, optional): Maximum number of instances allowed for a task. Tasks exceeding this limit are skipped.
        classification (bool | None, optional): Whether the model is a classifier (True) or regressor (False). If None, it is inferred from the model type.
        cache_directory (str | None, optional): Directory to save OpenML data. If None, default cache path is used.

    Returns:
        dict: A dictionary where keys are dataset names and values are tuples of (true targets, predicted labels, predicted probabilities).
    """
    if classification is None:
        classification = isinstance(model, NanoTabPFNClassifier)

    if cache_directory is not None:
        set_root_cache_directory(cache_directory)

    if isinstance(tasks, str):
        benchmark_suite = openml.study.get_suite(tasks)
        task_ids = benchmark_suite.tasks
    else:
        task_ids = tasks

    dataset_predictions = {}

    for task_id in task_ids:
        task = openml.tasks.get_task(task_id, download_splits=False)

        if classification and task.task_type_id != TaskType.SUPERVISED_CLASSIFICATION:
            continue  # skip task, only classification
        if not classification and task.task_type_id != TaskType.SUPERVISED_REGRESSION:
            continue  # skip task, only regression

        dataset = task.get_dataset(download_data=False)

        n_features = dataset.qualities["NumberOfFeatures"]
        n_instances = dataset.qualities["NumberOfInstances"]
        if n_features > max_n_features or n_instances > max_n_instances:
            continue  # skip task, too big

        _, folds, _ = task.get_split_dimensions()
        tabarena_light = True
        if tabarena_light:
            folds = 1  # code supports multiple folds but tabarena_light only has one
        repeat = 0  # code only supports one repeat
        targets = []
        predictions = []
        probabilities = []
        for fold in range(folds):
            X, y, categorical_indicator, attribute_names = dataset.get_data(
                target=task.target_name, dataset_format="dataframe"
            )
            train_indices, test_indices = task.get_train_test_split_indices(
                fold=fold, repeat=repeat
            )
            X_train = X.iloc[train_indices].to_numpy()
            y_train = y.iloc[train_indices].to_numpy()
            X_test = X.iloc[test_indices].to_numpy()
            y_test = y.iloc[test_indices].to_numpy()

            if classification:
                label_encoder = LabelEncoder()
                y_train = label_encoder.fit_transform(y_train)
                y_test = label_encoder.transform(y_test)
            targets.append(y_test)

            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            predictions.append(y_pred)
            if classification:
                y_proba = model.predict_proba(X_test)
                if y_proba.shape[1] == 2:  # binary classification
                    y_proba = y_proba[:, 1]
                probabilities.append(y_proba)

        y_pred = np.concatenate(predictions, axis=0)
        targets = np.concatenate(targets, axis=0)
        probabilities = np.concatenate(probabilities, axis=0) if len(probabilities) > 0 else None
        dataset_predictions[str(dataset.name)] = (targets, y_pred, probabilities)
    return dataset_predictions


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-model_type", type=str, choices=["regression", "classification"], required=True,
                        help="Whether to use the regressor or classifier model")
    parser.add_argument("-checkpoint", type=str, default=None,
                        help="Path to load the model weights from. If None, default weights are used.")
    parser.add_argument("-dist_path", type=str, default=None,
                        help="Path to load the bucket edges for the support bar distribution from. Only needed for regression.")
    parser.add_argument("-tasks", type=str, default="tabarena-v0.1",
                        choices=["tabarena-v0.1", "toy_tasks"], help="Which OpenML tasks to evaluate on.")
    parser.add_argument("-cache_directory", type=str, default=None,
                        help="Directory to save OpenML data. If None, default cache path is used.")
    parser.add_argument("-max_n_features", type=int, default=500,
                        help="Maximum number of features allowed for a task. Tasks exceeding this limit are skipped.")
    parser.add_argument("-max_n_instances", type=int, default=10_000,
                        help="Maximum number of instances allowed for a task. Tasks exceeding this limit are skipped.")
    args = parser.parse_args()

    if args.model_type == "classification":
        model = NanoTabPFNClassifier(model=args.checkpoint)
    else:
        model = NanoTabPFNRegressor(model=args.checkpoint, dist=args.dist_path)
    model.model.eval()

    if args.tasks == "toy_tasks" and args.model_type == "regression":
        tasks = TOY_TASKS_REGRESSION
    elif args.tasks == "toy_tasks" and args.model_type == "classification":
        tasks = TOY_TASKS_CLASSIFICATION
    else:
        tasks = args.tasks

    predictions = get_openml_predictions(
        model=model, tasks=tasks, max_n_features=args.max_n_features, max_n_instances=args.max_n_instances,
        classification=(args.model_type == "classification"), cache_directory=args.cache_directory
    )

    for dataset_name, (y_true, y_pred, y_proba) in predictions.items():
        if args.model_type == "classification":
            acc = balanced_accuracy_score(y_true, y_pred)
            auc = roc_auc_score(y_true, y_proba, multi_class='ovr')
            print(f"Dataset: {dataset_name} | ROC AUC: {auc:.4f} | Balanced Accuracy: {acc:.4f}")
        else:
            r2 = r2_score(y_true, y_pred)
            print(f"Dataset: {dataset_name} | R2: {r2:.4f}")
