1
- import skorch .utils
2
-
3
- # TODO: make it more safe
4
- old_to_tensor = skorch .utils .to_tensor
5
-
6
- def to_tensor (X , device , accept_sparse = False ):
7
- if isinstance (X , TensorFrame ):
8
- return X
9
- return old_to_tensor (X , device , accept_sparse )
10
-
11
- skorch .utils .to_tensor = to_tensor
12
1
import importlib
13
- importlib .reload (skorch .net )
14
-
15
2
from typing import Any
16
3
17
- import pandas as pd
4
+ import skorch . utils
18
5
import torch
19
- import torch .nn as nn
20
6
from numpy .typing import ArrayLike
21
7
from pandas import DataFrame
22
- from skorch import NeuralNet , NeuralNetClassifier
23
- from skorch .dataset import Dataset as SkorchDataset
8
+ from skorch import NeuralNet
24
9
from torch import Tensor
25
10
26
11
import torch_frame
@@ -29,20 +14,34 @@ def to_tensor(X, device, accept_sparse=False):
29
14
TextEmbedderConfig ,
30
15
TextTokenizerConfig ,
31
16
)
32
- from torch_frame .data .dataset import DataFrameToTensorFrameConverter , Dataset
17
+ from torch_frame .data .dataset import Dataset
33
18
from torch_frame .data .loader import DataLoader
34
19
from torch_frame .data .tensor_frame import TensorFrame
35
20
from torch_frame .typing import IndexSelectType
36
21
from torch_frame .utils import infer_df_stype
37
22
23
+ # TODO: make it more safe
24
+ old_to_tensor = skorch .utils .to_tensor
25
+
26
+
27
+ def to_tensor (X , device , accept_sparse = False ):
28
+ if isinstance (X , TensorFrame ):
29
+ return X
30
+ return old_to_tensor (X , device , accept_sparse )
31
+
32
+
33
+ skorch .utils .to_tensor = to_tensor
34
+
35
+ importlib .reload (skorch .net )
36
+
38
37
39
38
class NeuralNetPytorchFrameDataLoader (DataLoader ):
40
39
def __init__ (self , dataset : Dataset | TensorFrame , * args ,
41
40
device : torch .device , ** kwargs ):
42
41
super ().__init__ (dataset , * args , ** kwargs )
43
42
self .device = device
44
43
45
- def collate_fn (
44
+ def collate_fn ( # type: ignore
46
45
self , index : IndexSelectType ) -> tuple [TensorFrame , Tensor | None ]:
47
46
index = torch .tensor (index )
48
47
res = super ().collate_fn (index ).to (self .device )
0 commit comments