Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 150 additions & 2 deletions docs/source/llm/run-on-android.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,26 @@ To add the `executorch-android` library to your app, see [Using ExecuTorch on An

## Runtime API

Once the `executorch-android` AAR is on your classpath, you can import the LLM runner classes from the `org.pytorch.executorch.extension.llm` package.
Once the `executorch-android` AAR is on your classpath, you can import the LLM runner classes from the `org.pytorch.executorch.extension.llm` package. The runner is callable from both Java and Kotlin; the rest of this guide shows both side by side.

### Importing

Java:
```java
import org.pytorch.executorch.extension.llm.LlmModule;
import org.pytorch.executorch.extension.llm.LlmModuleConfig;
import org.pytorch.executorch.extension.llm.LlmGenerationConfig;
import org.pytorch.executorch.extension.llm.LlmCallback;
```

Kotlin:
```kotlin
import org.pytorch.executorch.extension.llm.LlmModule
import org.pytorch.executorch.extension.llm.LlmModuleConfig
import org.pytorch.executorch.extension.llm.LlmGenerationConfig
import org.pytorch.executorch.extension.llm.LlmCallback
```

### LlmModule
Comment thread
omkar-334 marked this conversation as resolved.

The `LlmModule` class provides a simple Java interface for loading a text-generation model, configuring its tokenizer, generating token streams, and stopping execution. It also supports multimodal models that accept image and audio inputs alongside a text prompt.
Expand All @@ -31,15 +40,26 @@ This API is experimental and subject to change.

Create an `LlmModule` by specifying paths to your serialized model (`.pte`) and tokenizer files. For text-only models, the simple constructor is enough:

Java:
```java
LlmModule module = new LlmModule(
"/data/local/tmp/llama-3.2-instruct.pte",
"/data/local/tmp/tokenizer.model",
0.8f);
```

Kotlin:
```kotlin
val module = LlmModule(
"/data/local/tmp/llama-3.2-instruct.pte",
"/data/local/tmp/tokenizer.model",
0.8f,
)
Comment thread
omkar-334 marked this conversation as resolved.
```

For finer control (multimodal model type, BOS/EOS handling, supplementary data files, load mode), use `LlmModuleConfig` with the fluent builder:

Java:
```java
LlmModuleConfig config = LlmModuleConfig.create()
.modulePath("/data/local/tmp/llama-3.2-instruct.pte")
Expand All @@ -52,27 +72,50 @@ LlmModuleConfig config = LlmModuleConfig.create()
LlmModule module = new LlmModule(config);
```

Available load modes are `LOAD_MODE_FILE`, `LOAD_MODE_MMAP` (default), `LOAD_MODE_MMAP_USE_MLOCK`, and `LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS`. Available model types are `MODEL_TYPE_TEXT`, `MODEL_TYPE_TEXT_VISION`, and `MODEL_TYPE_MULTIMODAL`.
Kotlin:
```kotlin
val config = LlmModuleConfig.create()
.modulePath("/data/local/tmp/llama-3.2-instruct.pte")
.tokenizerPath("/data/local/tmp/tokenizer.model")
.temperature(0.8f)
.modelType(LlmModuleConfig.MODEL_TYPE_TEXT)
.loadMode(LlmModuleConfig.LOAD_MODE_MMAP)
.build()

val module = LlmModule(config)
```

Available load modes are `LOAD_MODE_FILE`, `LOAD_MODE_MMAP` (default), `LOAD_MODE_MMAP_USE_MLOCK`, and `LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS`. Available model types are `MODEL_TYPE_TEXT` and `MODEL_TYPE_TEXT_VISION` (the `MODEL_TYPE_MULTIMODAL` constant is currently an alias for `MODEL_TYPE_TEXT_VISION` and selects the same runtime path).

Construction itself is lightweight and does not load the program data immediately.

#### Loading

Explicitly load the model before generation to avoid paying the load cost during your first `generate` call.

Java:
```java
int status = module.load();
if (status != 0) {
// Handle load failure (status is an ExecuTorch runtime error code).
}
Comment on lines 106 to 109
```

Kotlin:
```kotlin
val status = module.load()
if (status != 0) {
// Handle load failure (status is an ExecuTorch runtime error code).
}
Comment on lines +114 to +117
```

If you skip this step, the model is loaded lazily on the first `generate` call.

#### Generating

Generate tokens from a text prompt by passing an `LlmCallback` that receives each token as it is produced. The same callback also receives a JSON-encoded statistics string when generation completes.

Java:
```java
LlmCallback callback = new LlmCallback() {
@Override
Expand All @@ -97,8 +140,31 @@ LlmCallback callback = new LlmCallback() {
module.generate("Once upon a time", callback);
```

Kotlin:
```kotlin
val callback = object : LlmCallback {
override fun onResult(token: String) {
// Called once per generated token. Append to your UI buffer here.
print(token)
}

override fun onStats(statsJson: String) {
// Called once when generation finishes. See extension/llm/runner/stats.h
// for the field definitions.
println("\n$statsJson")
}

override fun onError(errorCode: Int, message: String) {
// Called if the runtime reports an error during generation.
}
}

module.generate("Once upon a time", callback)
```

For full control over generation parameters, use `LlmGenerationConfig`:

Java:
```java
LlmGenerationConfig genConfig = LlmGenerationConfig.create()
.seqLen(2048)
Expand All @@ -109,26 +175,49 @@ LlmGenerationConfig genConfig = LlmGenerationConfig.create()
module.generate("Once upon a time", genConfig, callback);
```

Kotlin:
```kotlin
val genConfig = LlmGenerationConfig.create()
.seqLen(2048)
.temperature(0.8f)
.echo(false)
.build()

module.generate("Once upon a time", genConfig, callback)
```

`LlmGenerationConfig` exposes `echo`, `maxNewTokens`, `seqLen`, `temperature`, `numBos`, `numEos`, and `warming`. Defaults match the C++ `GenerationConfig` documented in [Running LLMs with C++](run-with-c-plus-plus.md).

#### Stopping Generation

If you need to interrupt a long-running generation, call `stop()` from another thread (or from inside the `onResult` callback):

Java:
```java
module.stop();
```

Kotlin:
```kotlin
module.stop()
```

Generation also runs synchronously on the calling thread, so make sure you invoke `generate()` off the main thread (for example, on a `HandlerThread` or via a `java.util.concurrent.Executor`).

#### Resetting

To clear the prefilled tokens from the KV cache and reset the start position to 0, call:

Java:
```java
module.resetContext();
```

Kotlin:
```kotlin
module.resetContext()
```

This is the equivalent of `reset()` on the iOS runner and `reset()` on the C++ `IRunner`.

### Multimodal Inputs
Expand All @@ -139,6 +228,7 @@ For models declared as `MODEL_TYPE_TEXT_VISION` or `MODEL_TYPE_MULTIMODAL`, imag

Comment thread
omkar-334 marked this conversation as resolved.
Raw uint8 pixel data in CHW order can be supplied as an `int[]`, or as a direct `ByteBuffer` to avoid JNI array copies:

Java:
```java
// As int[]
int[] pixels = ...; // length == channels * height * width
Expand All @@ -150,8 +240,23 @@ buffer.put(rawBytes).rewind();
module.prefillImages(buffer, 336, 336, 3);
Comment on lines +246 to 251
```

Kotlin:
```kotlin
// As IntArray
val pixels: IntArray = ... // length == channels * height * width
module.prefillImages(pixels, /* width = */ 336, /* height = */ 336, /* channels = */ 3)

// As direct ByteBuffer (preferred for large images)
val buffer = ByteBuffer.allocateDirect(3 * 336 * 336).apply {
put(rawBytes)
rewind()
}
Comment thread
omkar-334 marked this conversation as resolved.
module.prefillImages(buffer, 336, 336, 3)
Comment thread
omkar-334 marked this conversation as resolved.
```

Pre-normalized float pixel data is also supported, both as a `float[]` and as a direct `ByteBuffer` in native byte order:

Java:
```java
float[] normalized = ...; // length == channels * height * width
module.prefillImages(normalized, 336, 336, 3);
Expand All @@ -163,31 +268,63 @@ ByteBuffer floatBuffer = ByteBuffer
module.prefillNormalizedImage(floatBuffer, 336, 336, 3);
Comment on lines 273 to 281
```

Kotlin:
```kotlin
val normalized: FloatArray = ... // length == channels * height * width
module.prefillImages(normalized, 336, 336, 3)

val floatBuffer: ByteBuffer = ByteBuffer
.allocateDirect(3 * 336 * 336 * Float.SIZE_BYTES)
.order(ByteOrder.nativeOrder())
Comment on lines +289 to +291

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to add import java.nio.ByteBuffer and import java.nio.ByteOrder stmt ?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @omkar-334, thanks for filing the PR as a follow-up to our conversation during the docathon. Could you address this review?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @nil-is-all , yes i've addressed this review, along with other copilot reviews. Thanks!

// fill floatBuffer with normalized values, then:
module.prefillNormalizedImage(floatBuffer, 336, 336, 3)
Comment thread
omkar-334 marked this conversation as resolved.
```

#### Audio

Preprocessed audio features (for example mel spectrograms produced by a Whisper preprocessor) can be supplied as `byte[]` or `float[]`:

Java:
```java
module.prefillAudio(features, /*batchSize=*/1, /*nBins=*/128, /*nFrames=*/3000);
```

Kotlin:
```kotlin
module.prefillAudio(features, /* batchSize = */ 1, /* nBins = */ 128, /* nFrames = */ 3000)
```

Raw audio samples can be supplied with `prefillRawAudio`:

Java:
```java
module.prefillRawAudio(samples, /*batchSize=*/1, /*nChannels=*/1, /*nSamples=*/16000);
```

Kotlin:
```kotlin
module.prefillRawAudio(samples, /* batchSize = */ 1, /* nChannels = */ 1, /* nSamples = */ 16000)
```

#### Generating with Multimodal Prefill

After prefilling each modality, run `generate()` with the text prompt as usual:

Java:
```java
module.prefillImages(pixels, 336, 336, 3);
module.generate("What's in this image?", callback);
```

Kotlin:
```kotlin
module.prefillImages(pixels, 336, 336, 3)
module.generate("What's in this image?", callback)
```

For text-vision models, a convenience overload accepts the image and prompt together:

Java:
```java
module.generate(
pixels, /*width=*/336, /*height=*/336, /*channels=*/3,
Expand All @@ -197,6 +334,17 @@ module.generate(
/*echo=*/false);
```

Kotlin:
```kotlin
module.generate(
pixels, /* width = */ 336, /* height = */ 336, /* channels = */ 3,
"What's in this image?",
/* seqLen = */ 768,
callback,
/* echo = */ false,
)
```

## Demo

See the [Llama Android demo app](https://github.com/meta-pytorch/executorch-examples/tree/main/llm/android/LlamaDemo) in `executorch-examples` for an end-to-end project that wires `LlmModule`, `LlmCallback`, and a `HandlerThread` into a chat UI.
Loading