Skip to content

Commit 0b9426f

Browse files
committed
style: format code
1 parent ca95b8f commit 0b9426f

File tree

1 file changed

+18
-19
lines changed

1 file changed

+18
-19
lines changed

torch_frame/utils/skorch.py

+18-19
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,11 @@
1-
import skorch.utils
2-
3-
# TODO: make it more safe
4-
old_to_tensor = skorch.utils.to_tensor
5-
6-
def to_tensor(X, device, accept_sparse=False):
7-
if isinstance(X, TensorFrame):
8-
return X
9-
return old_to_tensor(X, device, accept_sparse)
10-
11-
skorch.utils.to_tensor = to_tensor
121
import importlib
13-
importlib.reload(skorch.net)
14-
152
from typing import Any
163

17-
import pandas as pd
4+
import skorch.utils
185
import torch
19-
import torch.nn as nn
206
from numpy.typing import ArrayLike
217
from pandas import DataFrame
22-
from skorch import NeuralNet, NeuralNetClassifier
23-
from skorch.dataset import Dataset as SkorchDataset
8+
from skorch import NeuralNet
249
from torch import Tensor
2510

2611
import torch_frame
@@ -29,20 +14,34 @@ def to_tensor(X, device, accept_sparse=False):
2914
TextEmbedderConfig,
3015
TextTokenizerConfig,
3116
)
32-
from torch_frame.data.dataset import DataFrameToTensorFrameConverter, Dataset
17+
from torch_frame.data.dataset import Dataset
3318
from torch_frame.data.loader import DataLoader
3419
from torch_frame.data.tensor_frame import TensorFrame
3520
from torch_frame.typing import IndexSelectType
3621
from torch_frame.utils import infer_df_stype
3722

23+
# TODO: make it more safe
24+
old_to_tensor = skorch.utils.to_tensor
25+
26+
27+
def to_tensor(X, device, accept_sparse=False):
28+
if isinstance(X, TensorFrame):
29+
return X
30+
return old_to_tensor(X, device, accept_sparse)
31+
32+
33+
skorch.utils.to_tensor = to_tensor
34+
35+
importlib.reload(skorch.net)
36+
3837

3938
class NeuralNetPytorchFrameDataLoader(DataLoader):
4039
def __init__(self, dataset: Dataset | TensorFrame, *args,
4140
device: torch.device, **kwargs):
4241
super().__init__(dataset, *args, **kwargs)
4342
self.device = device
4443

45-
def collate_fn(
44+
def collate_fn( # type: ignore
4645
self, index: IndexSelectType) -> tuple[TensorFrame, Tensor | None]:
4746
index = torch.tensor(index)
4847
res = super().collate_fn(index).to(self.device)

0 commit comments

Comments
 (0)