WebGPU fixes #1293

Merged 3 commits on Apr 29, 2025

32 changes: 24 additions & 8 deletions src/models.js
@@ -237,6 +237,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
     const session_config = {
         dtype: selectedDtype,
         kv_cache_dtype,
+        device: selectedDevice,
     }
 
     // Construct the model file name
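
With this change, the selected device travels with the session config, so downstream code can branch on the chosen backend. For context, a minimal sketch of requesting WebGPU from user code, assuming the standard Transformers.js v3 loading options (`device`, `dtype`); the model id is illustrative:

```js
// Minimal sketch (assumes the @huggingface/transformers v3 API; model id is illustrative).
import { AutoModel } from '@huggingface/transformers';

// Ask for the WebGPU execution provider explicitly; `dtype` selects the weight precision.
const model = await AutoModel.from_pretrained('Xenova/all-MiniLM-L6-v2', {
  device: 'webgpu',
  dtype: 'fp32',
});
```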
@@ -417,6 +418,10 @@ function validateInputs(session, inputs) {
     return checkedInputs;
 }
 
+// Currently, Transformers.js doesn't support simultaneous execution of sessions in WASM/WebGPU.
+// For this reason, we need to chain the inference calls (otherwise we get "Error: Session already started").
+let webInferenceChain = Promise.resolve();
+
 /**
  * Executes an InferenceSession using the specified inputs.
  * NOTE: `inputs` must contain at least the input names of the model.
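
The `webInferenceChain` added above serializes inference calls by appending each new run to the tail of a shared promise. A standalone sketch of that pattern (not the library's exact code; the `.catch` guard is an assumption added here so one failed run doesn't block later ones):

```js
// Each queued task starts only after every previously queued task has settled.
let chain = Promise.resolve();

function runSerialized(task) {
    // Append the task to the tail of the chain.
    const result = chain.then(task);
    // Guard (not in the PR): swallow rejections on the chain itself,
    // so a failed task doesn't poison subsequent ones.
    chain = result.catch(() => {});
    return result;
}

// Usage: the second task starts only after the first finishes.
runSerialized(() => new Promise((resolve) => setTimeout(resolve, 100)));
runSerialized(async () => console.log("ran second"));
```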
@@ -433,17 +438,28 @@ async function sessionRun(session, inputs) {
     try {
         // pass the original ort tensor
         const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor]));
-        let output = await session.run(ortFeed);
-        output = replaceTensors(output);
-        return output;
+        const run = () => session.run(ortFeed);
+        const output = await ((apis.IS_BROWSER_ENV || apis.IS_WEBWORKER_ENV)
+            ? (webInferenceChain = webInferenceChain.then(run))
+            : run());
+        return replaceTensors(output);
     } catch (e) {
         // Error messages can be long (nested) and uninformative. For this reason,
         // we apply minor formatting to show the most important information
         const formatted = Object.fromEntries(Object.entries(checkedInputs)
-            .map(([k, { type, dims, data }]) => [k, {
-                type, dims, data,
-            }]));
+            .map(([k, tensor]) => {
+                // Extract these properties from the underlying ORT tensor
+                const unpacked = {
+                    type: tensor.type,
+                    dims: tensor.dims,
+                    location: tensor.location,
+                }
+                if (unpacked.location !== "gpu-buffer") {
+                    // Only return the data if it's not a GPU buffer
+                    unpacked.data = tensor.data;
+                }
+                return [k, unpacked];
+            }));
 
         // This usually occurs when the inputs are of the wrong type.
         console.error(`An error occurred during model execution: "${e}".`);

Collaborator (Author), commenting on lines +458 to +460: Needed, otherwise another error is thrown because the data is on the GPU, not the CPU, so we can't access it with `.data` without awaiting its contents (not needed for an error message).
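
That review note reflects how onnxruntime-web exposes GPU-resident tensors: the synchronous `.data` accessor is only valid for CPU data, while `getData()` asynchronously downloads the contents. A hedged sketch, assuming an onnxruntime-web `InferenceSession` and a model with a `logits` output:

```js
// Hedged sketch: reading an output tensor that may live on the GPU.
async function readLogits(session, feeds) {
    const results = await session.run(feeds);
    const logits = results.logits; // assumes the model exposes a 'logits' output

    if (logits.location === 'gpu-buffer') {
        // GPU-resident data can't be read via the synchronous `.data`
        // accessor; download it asynchronously instead.
        return await logits.getData();
    }
    // CPU-resident data is directly accessible.
    return logits.data;
}
```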
@@ -5223,7 +5239,7 @@ export class RTDetrV2ForObjectDetection extends RTDetrV2PreTrainedModel {
     }
 }
 
-export class RTDetrV2ObjectDetectionOutput extends RTDetrObjectDetectionOutput {}
+export class RTDetrV2ObjectDetectionOutput extends RTDetrObjectDetectionOutput { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
@@ -5238,7 +5254,7 @@ export class RFDetrForObjectDetection extends RFDetrPreTrainedModel {
     }
 }
 
-export class RFDetrObjectDetectionOutput extends RTDetrObjectDetectionOutput {}
+export class RFDetrObjectDetectionOutput extends RTDetrObjectDetectionOutput { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
61 changes: 61 additions & 0 deletions tests/utils/generation.test.js
@@ -282,6 +282,67 @@ describe("PKV caching", () => {
     }, MAX_MODEL_DISPOSE_TIME);
   });
 
+  describe("LlamaForCausalLM (onnxruntime-genai)", () => {
+    const model_id = "onnx-internal-testing/tiny-random-LlamaForCausalLM-GQA";
+    /** @type {LlamaForCausalLM} */
+    let model;
+    /** @type {LlamaTokenizer} */
+    let tokenizer;
+    beforeAll(async () => {
+      model = await LlamaForCausalLM.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
+      tokenizer = await LlamaTokenizer.from_pretrained(model_id);
+    }, MAX_MODEL_LOAD_TIME);
+
+    it(
+      "batch_size=1",
+      async () => {
+        const inputs = tokenizer("1");
+
+        // Generate first sequence w/o PKV
+        // NOTE: `return_dict_in_generate=true` is required to get PKV
+        const { past_key_values, sequences } = await model.generate({
+          ...inputs,
+          max_new_tokens: 5,
+          do_sample: false,
+          return_dict_in_generate: true,
+        });
+
+        // Update output with new text
+        const decoded = tokenizer.batch_decode(sequences, {
+          skip_special_tokens: false,
+        })[0];
+        const new_inputs = tokenizer(decoded + "2", {
+          add_special_tokens: false,
+        });
+
+        // Run w/o PKV
+        const generated_ids = await model.generate({
+          ...new_inputs,
+          max_new_tokens: 3,
+          do_sample: false,
+        });
+
+        // Run w/ PKV
+        const generated_ids_pkv = await model.generate({
+          ...new_inputs,
+          past_key_values,
+          max_new_tokens: 3,
+          do_sample: false,
+        });
+
+        const target = [[128000n, 16n, 34732n, 98805n, 116404n, 68265n, 99392n, 17n, 21855n, 60933n, 14285n]];
+
+        expect(generated_ids.tolist()).toEqual(target);
+        expect(generated_ids_pkv.tolist()).toEqual(target);
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+
+    afterAll(async () => {
+      await model?.dispose();
+    }, MAX_MODEL_DISPOSE_TIME);
+  });
+
   describe("LlavaForConditionalGeneration", () => {
     const model_id = "Xenova/tiny-random-LlavaForConditionalGeneration";
     /** @type {LlavaForConditionalGeneration} */
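
Outside the test harness, the KV-cache reuse the test verifies looks like the sketch below, reusing the test's `model` and `tokenizer` with an illustrative prompt; continuing from `past_key_values` avoids re-processing the shared prefix:

```js
// First call: request the cache back alongside the generated tokens.
const { past_key_values, sequences } = await model.generate({
  ...tokenizer("Once upon a time"),
  max_new_tokens: 5,
  do_sample: false,
  return_dict_in_generate: true, // required to get past_key_values back
});

// Second call: append new text and resume from the cached keys/values.
const decoded = tokenizer.batch_decode(sequences, { skip_special_tokens: false })[0];
const continued = await model.generate({
  ...tokenizer(decoded + " and then", { add_special_tokens: false }),
  past_key_values,
  max_new_tokens: 3,
  do_sample: false,
});
```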