@@ -3,12 +3,20 @@
 from pathlib import Path
 import subprocess
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Dict
 from enum import Enum, auto
 
+from sharktank.utils.hf_datasets import Dataset, RemoteFile, get_dataset
+
 logger = logging.getLogger(__name__)
 
 
+class AccuracyValidationException(RuntimeError):
+    """Exception raised when accuracy validation fails."""
+
+    pass
+
+
 class ModelSource(Enum):
     HUGGINGFACE = auto()
     LOCAL = auto()
@@ -34,13 +42,17 @@ class ModelConfig:
     batch_sizes: Tuple[int, ...]
     device_settings: "DeviceSettings"
     source: ModelSource
+    dataset_name: Optional[str] = None  # Name of the dataset in hf_datasets.py
     repo_id: Optional[str] = None
     local_path: Optional[Path] = None
     azure_config: Optional[AzureConfig] = None
 
     def __post_init__(self):
-        if self.source == ModelSource.HUGGINGFACE and not self.repo_id:
-            raise ValueError("repo_id required for HuggingFace models")
+        if self.source == ModelSource.HUGGINGFACE:
+            if not (self.dataset_name or self.repo_id):
+                raise ValueError(
+                    "Either dataset_name or repo_id required for HuggingFace models"
+                )
         elif self.source == ModelSource.LOCAL and not self.local_path:
             raise ValueError("local_path required for local models")
         elif self.source == ModelSource.AZURE and not self.azure_config:
@@ -70,6 +82,8 @@ def __init__(self, base_dir: Path, config: ModelConfig):
     def _get_model_dir(self) -> Path:
         """Creates and returns appropriate model directory based on source."""
         if self.config.source == ModelSource.HUGGINGFACE:
+            if self.config.dataset_name:
+                return self.base_dir / self.config.dataset_name.replace("/", "_")
             return self.base_dir / self.config.repo_id.replace("/", "_")
         elif self.config.source == ModelSource.LOCAL:
             return self.base_dir / "local" / self.config.local_path.stem
@@ -82,15 +96,36 @@ def _get_model_dir(self) -> Path:
             raise ValueError(f"Unsupported model source: {self.config.source}")
 
     def _download_from_huggingface(self) -> Path:
-        """Downloads model from HuggingFace."""
+        """Downloads model from HuggingFace using hf_datasets.py."""
         model_path = self.model_dir / self.config.model_file
         if not model_path.exists():
-            logger.info(f"Downloading model {self.config.repo_id} from HuggingFace")
-            subprocess.run(
-                f"huggingface-cli download --local-dir {self.model_dir} {self.config.repo_id} {self.config.model_file}",
-                shell=True,
-                check=True,
-            )
+            if self.config.dataset_name:
+                logger.info(
+                    f"Downloading model {self.config.dataset_name} using hf_datasets"
+                )
+                dataset = get_dataset(self.config.dataset_name)
+                downloaded_files = dataset.download(local_dir=self.model_dir)
+
+                # Find the model file in downloaded files
+                for file_id, paths in downloaded_files.items():
+                    for path in paths:
+                        if path.name == self.config.model_file:
+                            return path
+
+                raise ValueError(
+                    f"Model file {self.config.model_file} not found in dataset {self.config.dataset_name}"
+                )
+            else:
+                logger.info(f"Downloading model {self.config.repo_id} from HuggingFace")
+                # Create a temporary dataset for direct repo downloads
+                remote_file = RemoteFile(
+                    file_id="model",
+                    repo_id=self.config.repo_id,
+                    filename=self.config.model_file,
+                )
+                downloaded_paths = remote_file.download(local_dir=self.model_dir)
+                return downloaded_paths[0]
+
         return model_path
 
     def _copy_from_local(self) -> Path:
@@ -132,14 +167,30 @@ def _download_from_azure(self) -> Path:
         return model_path
 
     def prepare_tokenizer(self) -> Path:
-        """Downloads and prepares tokenizer."""
+        """Downloads and prepares tokenizer using hf_datasets.py when possible."""
         tokenizer_path = self.model_dir / "tokenizer.json"
+
         if not tokenizer_path.exists():
-            logger.info(f"Downloading tokenizer {self.config.tokenizer_id}")
+            # First try to get tokenizer from dataset if available
+            if self.config.dataset_name:
+                dataset = get_dataset(self.config.dataset_name)
+                downloaded_files = dataset.download(local_dir=self.model_dir)
+
+                # Look for tokenizer files in downloaded files
+                for file_id, paths in downloaded_files.items():
+                    for path in paths:
+                        if path.name == "tokenizer.json":
+                            return path
+
+            # Fall back to downloading from transformers if not found in dataset
+            logger.info(
+                f"Downloading tokenizer {self.config.tokenizer_id} using transformers"
+            )
             from transformers import AutoTokenizer
 
             tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_id)
             tokenizer.save_pretrained(self.model_dir)
+
         return tokenizer_path
 
     def export_model(self, weights_path: Path) -> Tuple[Path, Path]:
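
For reference, a minimal sketch of how the new hf_datasets-backed paths above are driven. The dataset name, repo id, and filename below are hypothetical placeholders; real dataset names are whatever is registered in sharktank/utils/hf_datasets.py. The return shapes (a dict of file_id to local paths for Dataset.download, a list of paths for RemoteFile.download) are inferred from how this diff consumes them.

    from pathlib import Path

    from sharktank.utils.hf_datasets import RemoteFile, get_dataset

    # Dataset-based path: "open_llama_3b_v2_f16_gguf" is a hypothetical name and
    # must match an entry registered in hf_datasets.py.
    downloaded = get_dataset("open_llama_3b_v2_f16_gguf").download(
        local_dir=Path("/tmp/models")
    )
    for file_id, paths in downloaded.items():
        # download() maps each file_id to the list of local paths it produced,
        # which is why the diff scans the values for the configured model_file.
        print(file_id, [p.name for p in paths])

    # repo_id fallback: fetch a single file directly, as _download_from_huggingface
    # does when no dataset_name is configured. Repo and filename are illustrative.
    remote = RemoteFile(
        file_id="model",
        repo_id="openlm-research/open_llama_3b",
        filename="model.safetensors",
    )
    paths = remote.download(local_dir=Path("/tmp/models"))

Wrapping single-file downloads in a RemoteFile keeps both code paths returning local Path objects, so callers never need to know which mechanism fetched the weights.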