@@ -75,8 +75,7 @@ def set_handlers(
75
75
enabled_handler : typing .Callable [[bool , AsyncNextcloudApp | NextcloudApp ], typing .Awaitable [str ] | str ],
76
76
heartbeat_handler : typing .Callable [[], typing .Awaitable [str ] | str ] | None = None ,
77
77
init_handler : typing .Callable [[AsyncNextcloudApp | NextcloudApp ], typing .Awaitable [None ] | None ] | None = None ,
78
- models_to_fetch : list [str ] | None = None ,
79
- models_download_params : dict | None = None ,
78
+ models_to_fetch : dict [str , dict ] | None = None ,
80
79
map_app_static : bool = True ,
81
80
):
82
81
"""Defines handlers for the application.
@@ -92,7 +91,6 @@ def set_handlers(
92
91
93
92
.. note:: ```huggingface_hub`` package should be present for automatic models fetching.
94
93
95
- :param models_download_params: Parameters to pass to ``snapshot_download`` function from **huggingface_hub**.
96
94
:param map_app_static: Should be folders ``js``, ``css``, ``l10n``, ``img`` automatically mounted in FastAPI or not.
97
95
98
96
.. note:: First, presence of these directories in the current working dir is checked, then one directory higher.
@@ -140,8 +138,7 @@ async def init_callback(
140
138
background_tasks .add_task (
141
139
__fetch_models_task ,
142
140
nc ,
143
- models_to_fetch if models_to_fetch else [],
144
- models_download_params if models_download_params else {},
141
+ models_to_fetch if models_to_fetch else {},
145
142
)
146
143
return responses .JSONResponse (content = {}, status_code = 200 )
147
144
@@ -181,8 +178,7 @@ def __map_app_static_folders(fast_api_app: FastAPI):
181
178
182
179
def __fetch_models_task (
183
180
nc : NextcloudApp ,
184
- models : list [str ],
185
- params : dict [str , typing .Any ],
181
+ models : dict [str , dict ],
186
182
) -> None :
187
183
if models :
188
184
from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401
@@ -193,10 +189,8 @@ def display(self, msg=None, pos=None):
193
189
nc .set_init_status (min (int ((self .n * 100 / self .total ) / len (models )), 100 ))
194
190
return super ().display (msg , pos )
195
191
196
- if "max_workers" not in params :
197
- params ["max_workers" ] = 2
198
- if "cache_dir" not in params :
199
- params ["cache_dir" ] = persistent_storage ()
200
192
for model in models :
201
- snapshot_download (model , tqdm_class = TqdmProgress , ** params ) # noqa
193
+ workers = models [model ].pop ("max_workers" , 2 )
194
+ cache = models [model ].pop ("cache_dir" , persistent_storage ())
195
+ snapshot_download (model , tqdm_class = TqdmProgress , ** models [model ], max_workers = workers , cache_dir = cache )
202
196
nc .set_init_status (100 )
0 commit comments