diff --git a/src/configs.js b/src/configs.js index 4303169f3..d95661a84 100644 --- a/src/configs.js +++ b/src/configs.js @@ -367,6 +367,7 @@ export class PretrainedConfig { cache_dir = null, local_files_only = false, revision = 'main', + request_options = {}, } = {}) { if (config && !(config instanceof PretrainedConfig)) { config = new PretrainedConfig(config); @@ -378,6 +379,7 @@ export class PretrainedConfig { cache_dir, local_files_only, revision, + request_options }) return new this(data); } diff --git a/src/env.js b/src/env.js index f351c47f8..04d6c6209 100644 --- a/src/env.js +++ b/src/env.js @@ -118,6 +118,7 @@ const localModelPath = RUNNING_LOCALLY * @property {string} cacheDir The directory to use for caching files with the file system. By default, it is `./.cache`. * @property {boolean} useCustomCache Whether to use a custom cache system (defined by `customCache`), defaults to `false`. * @property {Object} customCache The custom cache to use. Defaults to `null`. Note: this must be an object which * implements the `match` and `put` functions of the Web Cache API. 
For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache + * @property {(input: RequestInfo | URL, init?: RequestInit) => Promise} customFetch A custom fetch function to use. Defaults to `null`. Note: this must be a function which has the same signature as the global `fetch` function. */ @@ -150,6 +151,10 @@ export const env = { useCustomCache: false, customCache: null, ////////////////////////////////////////////////////// + + /////////////////// custom settings /////////////////// + customFetch: null, + ////////////////////////////////////////////////////// } diff --git a/src/models.js b/src/models.js index 204fca4eb..18b7c7dbc 100644 --- a/src/models.js +++ b/src/models.js @@ -985,6 +985,7 @@ export class PreTrainedModel extends Callable { dtype = null, use_external_data_format = null, session_options = {}, + request_options = {} } = {}) { let options = { @@ -999,6 +1000,7 @@ export class PreTrainedModel extends Callable { dtype, use_external_data_format, session_options, + request_options } const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this); @@ -6999,6 +7001,7 @@ export class PretrainedMixin { dtype = null, use_external_data_format = null, session_options = {}, + request_options = {} } = {}) { const options = { @@ -7013,6 +7016,7 @@ export class PretrainedMixin { dtype, use_external_data_format, session_options, + request_options, } options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options); diff --git a/src/pipelines.js b/src/pipelines.js index 649b00a49..90d78dc0c 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -3301,6 +3301,7 @@ export async function pipeline( dtype = null, model_file_name = null, session_options = {}, + request_options = {} } = {} ) { // Helper method to construct pipeline @@ -3331,6 +3332,7 @@ export async function pipeline( dtype, model_file_name, session_options, + request_options, } const classes = new Map([ diff --git a/src/tokenizers.js b/src/tokenizers.js index 42d8dc04c..b9e918ea1 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -2682,6 +2682,7 @@ export class PreTrainedTokenizer extends Callable { local_files_only = false, revision = 'main', legacy = null, + request_options = {}, 
} = {}) { const info = await loadTokenizer(pretrained_model_name_or_path, { @@ -2691,6 +2692,7 @@ export class PreTrainedTokenizer extends Callable { local_files_only, revision, legacy, + request_options, }) // @ts-ignore @@ -4351,6 +4353,7 @@ export class AutoTokenizer { local_files_only = false, revision = 'main', legacy = null, + request_options = {} } = {}) { const [tokenizerJSON, tokenizerConfig] = await loadTokenizer(pretrained_model_name_or_path, { @@ -4360,6 +4363,7 @@ export class AutoTokenizer { local_files_only, revision, legacy, + request_options }) // Some tokenizers are saved with the "Fast" suffix, so we remove that if present. diff --git a/src/utils/hub.js b/src/utils/hub.js index 17ee4c1b1..ea7d41c7e 100755 --- a/src/utils/hub.js +++ b/src/utils/hub.js @@ -20,6 +20,7 @@ import { dispatchCallback } from './core.js'; * @property {string} [cache_dir=null] Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. * @property {boolean} [local_files_only=false] Whether or not to only look at local files (e.g., not try downloading the model). * @property {string} [revision='main'] The specific model version to use. It can be a branch name, a tag name, or a commit id, * since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. * NOTE: This setting is ignored for local requests. + * @property {RequestInit} [request_options] The options to use when making the request. */ @@ -185,18 +186,23 @@ function isValidUrl(string, protocols = null, validHosts = null) { * Helper function to get a file, using either the Fetch API or FileSystem API. * * @param {URL|string} urlOrPath The URL/path of the file to get. + * @param {RequestInit} [request_options] The options to use when making the request. 
* @returns {Promise} A promise that resolves to a FileResponse object (if the file is retrieved using the FileSystem API), or a Response object (if the file is retrieved using the Fetch API). */ -export async function getFile(urlOrPath) { +export async function getFile(urlOrPath, request_options) { + /** + * @type {Headers} The headers to use when making the request. + */ + let headers + if (env.useFS && !isValidUrl(urlOrPath, ['http:', 'https:', 'blob:'])) { return new FileResponse(urlOrPath); - } else if (typeof process !== 'undefined' && process?.release?.name === 'node') { const IS_CI = !!process.env?.TESTING_REMOTELY; const version = env.version; - const headers = new Headers(); + headers = new Headers(); headers.set('User-Agent', `transformers.js/${version}; is_ci/${IS_CI};`); // Check whether we are making a request to the Hugging Face Hub. @@ -210,13 +216,23 @@ export async function getFile(urlOrPath) { headers.set('Authorization', `Bearer ${token}`); } } - return fetch(urlOrPath, { headers }); } else { // Running in a browser-environment, so we use default headers // NOTE: We do not allow passing authorization headers in the browser, // since this would require exposing the token to the client. - return fetch(urlOrPath); } + + /** + * @type {(input: RequestInfo | URL, init?: RequestInit) => Promise} A custom fetch function to use. Defaults to `null`. 
Note: this must be a function which has the same signature as the global `fetch` function. + */ + let resolvedFetch; + if (env.customFetch) { + resolvedFetch = env.customFetch; + } else { + resolvedFetch = fetch + } + + return resolvedFetch(urlOrPath, {headers, ...request_options}); } const ERROR_MAPPING = { @@ -447,7 +463,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti const isURL = isValidUrl(requestURL, ['http:', 'https:']); if (!isURL) { try { - response = await getFile(localPath); + response = await getFile(localPath, options.request_options); cacheKey = localPath; // Update the cache key to be the local path } catch (e) { // Something went wrong while trying to get the file locally. @@ -479,7 +495,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti } // File not found locally, so we try to download it from the remote server - response = await getFile(remoteURL); + response = await getFile(remoteURL, options.request_options); if (response.status !== 200) { return handleError(response.status, remoteURL, fatal); diff --git a/tests/utils/hub.test.js b/tests/utils/hub.test.js index 3ef3f41f7..1b86ec624 100644 --- a/tests/utils/hub.test.js +++ b/tests/utils/hub.test.js @@ -36,5 +36,35 @@ describe("Hub", () => { }, MAX_TEST_EXECUTION_TIME, ); + + it("should cancel model loading", async () => { + const controller = new AbortController(); + const signal = controller.signal; + setTimeout(() => controller.abort(), 10); + try { + await AutoModel.from_pretrained("hf-internal-testing/this-model-does-not-exist", { ...DEFAULT_MODEL_OPTIONS, request_options: { signal } }) + } catch (error) { + expect(error.name).toBe("AbortError"); + } + }, MAX_TEST_EXECUTION_TIME + 1000); + + it("should cancel multiple model loading", async () => { + const controller = new AbortController(); + const signal = controller.signal; + setTimeout(() => controller.abort(), 10); + + try { + await AutoModel.from_pretrained("hf-internal-testing/this-model-does-not-exist", { 
...DEFAULT_MODEL_OPTIONS, request_options: { signal } }) + } catch (error) { + expect(error.name).toBe("AbortError"); + } + + try { + await AutoModel.from_pretrained("hf-internal-testing/this-model-does-not-exist", { ...DEFAULT_MODEL_OPTIONS, request_options: { signal } }) + } catch (error) { + expect(error.name).toBe("AbortError"); + } + + }, MAX_TEST_EXECUTION_TIME + 1000); }); });