diff --git a/src/models.js b/src/models.js
index 6c0dfb73f..9d5d4d12c 100644
--- a/src/models.js
+++ b/src/models.js
@@ -237,6 +237,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
     const session_config = {
         dtype: selectedDtype,
         kv_cache_dtype,
+        device: selectedDevice,
     }
 
     // Construct the model file name
@@ -417,6 +418,10 @@ function validateInputs(session, inputs) {
     return checkedInputs;
 }
 
+// Currently, Transformers.js doesn't support simultaneous execution of sessions in WASM/WebGPU.
+// For this reason, we need to chain the inference calls (otherwise we get "Error: Session already started").
+let webInferenceChain = Promise.resolve();
+
 /**
  * Executes an InferenceSession using the specified inputs.
  * NOTE: `inputs` must contain at least the input names of the model.
@@ -433,17 +438,28 @@ async function sessionRun(session, inputs) {
     try {
         // pass the original ort tensor
         const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor]));
-        let output = await session.run(ortFeed);
-        output = replaceTensors(output);
-        return output;
+        const run = () => session.run(ortFeed);
+        const output = await ((apis.IS_BROWSER_ENV || apis.IS_WEBWORKER_ENV)
+            ? (webInferenceChain = webInferenceChain.then(run))
+            : run());
+        return replaceTensors(output);
     } catch (e) {
         // Error messages can be long (nested) and uninformative. For this reason,
         // we apply minor formatting to show the most important information
         const formatted = Object.fromEntries(Object.entries(checkedInputs)
-            .map(([k, { type, dims, data }]) => [k, {
+            .map(([k, tensor]) => {
                 // Extract these properties from the underlying ORT tensor
-                type, dims, data,
-            }]));
+                const unpacked = {
+                    type: tensor.type,
+                    dims: tensor.dims,
+                    location: tensor.location,
+                }
+                if (unpacked.location !== "gpu-buffer") {
+                    // Only return the data if it's not a GPU buffer
+                    unpacked.data = tensor.data;
+                }
+                return [k, unpacked];
+            }));
 
         // This usually occurs when the inputs are of the wrong type.
         console.error(`An error occurred during model execution: "${e}".`);
@@ -5223,7 +5239,7 @@ export class RTDetrV2ForObjectDetection extends RTDetrV2PreTrainedModel {
     }
 }
 
-export class RTDetrV2ObjectDetectionOutput extends RTDetrObjectDetectionOutput {}
+export class RTDetrV2ObjectDetectionOutput extends RTDetrObjectDetectionOutput { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
@@ -5238,7 +5254,7 @@ export class RFDetrForObjectDetection extends RFDetrPreTrainedModel {
     }
 }
 
-export class RFDetrObjectDetectionOutput extends RTDetrObjectDetectionOutput {}
+export class RFDetrObjectDetectionOutput extends RTDetrObjectDetectionOutput { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
diff --git a/tests/utils/generation.test.js b/tests/utils/generation.test.js
index 2eab931c4..377816ff3 100644
--- a/tests/utils/generation.test.js
+++ b/tests/utils/generation.test.js
@@ -282,6 +282,67 @@ describe("PKV caching", () => {
     }, MAX_MODEL_DISPOSE_TIME);
   });
 
+  describe("LlamaForCausalLM (onnxruntime-genai)", () => {
+    const model_id = "onnx-internal-testing/tiny-random-LlamaForCausalLM-GQA";
+    /** @type {LlamaForCausalLM} */
+    let model;
+    /** @type {LlamaTokenizer} */
+    let tokenizer;
+    beforeAll(async () => {
+      model = await LlamaForCausalLM.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
+      tokenizer = await LlamaTokenizer.from_pretrained(model_id);
+    }, MAX_MODEL_LOAD_TIME);
+
+    it(
+      "batch_size=1",
+      async () => {
+        const inputs = tokenizer("1");
+
+        // Generate first sequence w/o PKV
+        // NOTE: `return_dict_in_generate=true` is required to get PKV
+        const { past_key_values, sequences } = await model.generate({
+          ...inputs,
+          max_new_tokens: 5,
+          do_sample: false,
+          return_dict_in_generate: true,
+        });
+
+        // Update output with new text
+        const decoded = tokenizer.batch_decode(sequences, {
+          skip_special_tokens: false,
+        })[0];
+        const new_inputs = tokenizer(decoded + "2", {
+          add_special_tokens: false,
+        });
+
+        // Run w/o PKV
+        const generated_ids = await model.generate({
+          ...new_inputs,
+          max_new_tokens: 3,
+          do_sample: false,
+        });
+
+        // Run w/ PKV
+        const generated_ids_pkv = await model.generate({
+          ...new_inputs,
+          past_key_values,
+          max_new_tokens: 3,
+          do_sample: false,
+        });
+
+        const target = [[128000n, 16n, 34732n, 98805n, 116404n, 68265n, 99392n, 17n, 21855n, 60933n, 14285n]];
+
+        expect(generated_ids.tolist()).toEqual(target);
+        expect(generated_ids_pkv.tolist()).toEqual(target);
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+
+    afterAll(async () => {
+      await model?.dispose();
+    }, MAX_MODEL_DISPOSE_TIME);
+  });
+
   describe("LlavaForConditionalGeneration", () => {
     const model_id = "Xenova/tiny-random-LlavaForConditionalGeneration";
     /** @type {LlavaForConditionalGeneration} */
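
The `webInferenceChain` change serializes inference: routing every browser/worker call through a single module-level promise guarantees that at most one `session.run()` is in flight at a time, which is what prevents the "Error: Session already started" failure. A minimal standalone sketch of the same pattern (`chain` and `serialize` are illustrative names, not part of the Transformers.js API):

    // Shared chain: each task starts only after the previous one settles.
    let chain = Promise.resolve();

    function serialize(task) {
        const result = chain.then(task);
        // Swallow rejections when advancing the chain so one failed task
        // does not block every task queued after it.
        chain = result.catch(() => { });
        return result;
    }

    // Both calls are issued concurrently, but the second task does not
    // start until the first has finished (logs "first", then "second").
    serialize(() => new Promise((resolve) => setTimeout(() => resolve("first"), 50))).then(console.log);
    serialize(() => new Promise((resolve) => setTimeout(() => resolve("second"), 10))).then(console.log);

Note that the patch advances the chain with a bare `webInferenceChain = webInferenceChain.then(run)`; the `.catch` guard in this sketch is an extra safety measure, not something the diff itself does.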
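The new generation test exercises the past-key-values (PKV) round trip: `return_dict_in_generate: true` makes `generate` return the cache alongside the sequences, and feeding that cache back in lets the follow-up call skip re-processing tokens it has already seen. Condensed into application-style code, the flow the test verifies looks roughly like this (a sketch only; the model id is the test's tiny model, and the `Auto*` classes stand in for the concrete Llama classes used in the test):

    import { AutoTokenizer, AutoModelForCausalLM } from "@huggingface/transformers";

    const model_id = "onnx-internal-testing/tiny-random-LlamaForCausalLM-GQA";
    const tokenizer = await AutoTokenizer.from_pretrained(model_id);
    const model = await AutoModelForCausalLM.from_pretrained(model_id);

    // First call: ask for the cache back alongside the generated tokens.
    const { past_key_values, sequences } = await model.generate({
        ...tokenizer("1"),
        max_new_tokens: 5,
        do_sample: false,
        return_dict_in_generate: true,
    });

    // Re-tokenize the decoded text so far plus the new input...
    const decoded = tokenizer.batch_decode(sequences, { skip_special_tokens: false })[0];
    const new_inputs = tokenizer(decoded + "2", { add_special_tokens: false });

    // ...then resume with the cache; the test asserts this produces the
    // same tokens as generating without the cache.
    const continued = await model.generate({
        ...new_inputs,
        past_key_values,
        max_new_tokens: 3,
        do_sample: false,
    });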