Skip to content

Commit 89f5f70

Browse files
Arm backend: Added support for example models.
Added missing enablement of model definitions within examples/models for model testing. Change-Id: I4fc8cb7377821b1ba2c262da692e016713418f70 Signed-off-by: Michiel Olieslagers <michiel.olieslagers@arm.com>
1 parent 21a0191 commit 89f5f70

6 files changed

Lines changed: 100 additions & 0 deletions

File tree

examples/models/BUCK

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ fbcode_target(_kind = python_library,
3333
"//executorch/examples/models/phi_4_mini:phi_4_mini", # @manual
3434
"//executorch/examples/models/smollm2:smollm2", # @manual
3535
"//executorch/examples/models/smollm3:smollm3", # @manual
36+
"//executorch/examples/models/smolvlm:smolvlm", # @manual
37+
"//executorch/examples/models/whisper:whisper", # @manual
38+
"//executorch/examples/models/yolo26:yolo26", # @manual
3639
],
3740
)
3841

examples/models/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ class Model(str, Enum):
4545
MobileNetV1025 = "mobilenet_v1_025"
4646
ResNet8 = "resnet8"
4747
Sdpa = "sdpa"
48+
Qwen3 = "qwen3"
49+
SmolVLM = "smolvlm"
50+
YOLO26 = "yolo26"
51+
Whisper = "whisper"
4852

4953
def __str__(self) -> str:
5054
return self.value
@@ -105,6 +109,10 @@ def __str__(self) -> str:
105109
),
106110
str(Model.ResNet8): ("mlperf_tiny.resnet8", "ResNet8Model"),
107111
str(Model.Sdpa): ("toy_model", "SdpaModule"),
112+
str(Model.Qwen3): ("qwen3", "Qwen3Model"),
113+
str(Model.SmolVLM): ("smolvlm", "SmolVLMModel"),
114+
str(Model.YOLO26): ("yolo26", "YOLO26Model"),
115+
str(Model.Whisper): ("whisper", "WhisperModel"),
108116
}
109117

110118
__all__ = [
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from .model import WhisperModel
7+
8+
__all__ = [WhisperModel]

examples/models/whisper/model.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import logging
7+
8+
import torch
9+
from transformers import WhisperModel as wm
10+
11+
from ..model_base import EagerModelBase
12+
13+
14+
class WhisperWrapper(torch.nn.Module):
15+
def __init__(self, model):
16+
super().__init__()
17+
self.model = model
18+
19+
def forward(self, input_features, decoder_input_ids):
20+
return self.model(
21+
input_features=input_features,
22+
decoder_input_ids=decoder_input_ids,
23+
use_cache=False,
24+
).last_hidden_state
25+
26+
27+
class WhisperModel(EagerModelBase):
28+
def __init__(self, model_name: str = "openai/whisper-base"):
29+
self.model = wm.from_pretrained(model_name)
30+
31+
def get_eager_model(self) -> torch.nn.Module:
32+
logging.info("Loading Whisper model")
33+
m = WhisperWrapper(self.model).eval()
34+
logging.info("Loaded Whisper model")
35+
return m
36+
37+
def get_example_inputs(self):
38+
return (
39+
torch.randn(1, 80, 3000),
40+
torch.tensor(
41+
[[self.model.config.decoder_start_token_id]], dtype=torch.long
42+
),
43+
)

examples/models/yolo26/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from .model import YOLO26Model
7+
8+
__all__ = [YOLO26Model]

examples/models/yolo26/model.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import logging
7+
8+
import torch
9+
10+
from ultralytics import YOLO
11+
12+
from ..model_base import EagerModelBase
13+
14+
15+
class YOLO26Model(EagerModelBase):
16+
def __init__(self, model_name: str = "yolo26n.pt"):
17+
self.model_name = model_name
18+
self.yolo = YOLO(model_name)
19+
self.dummy_frame = torch.randn((320, 320, 3)).to(torch.uint8).numpy()
20+
self.yolo.predict(self.dummy_frame, imgsz=(320, 320), verbose=False)
21+
self.model = self.yolo.model.eval()
22+
23+
def get_eager_model(self) -> torch.nn.Module:
24+
logging.info("Loading " + self.model_name + " model")
25+
m = self.model
26+
logging.info("Loaded " + self.model_name + " model")
27+
return m
28+
29+
def get_example_inputs(self):
30+
return (self.yolo.predictor.preprocess([self.dummy_frame]),)

0 commit comments

Comments
 (0)