Skip to content

Commit a67f12e

Browse files
Update (base update)
[ghstack-poisoned]
2 parents d75cc14 + e03f777 commit a67f12e

66 files changed

Lines changed: 2801 additions & 689 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

.github/workflows/pull.yml

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -816,6 +816,33 @@ jobs:
816816
# Test test_arm_backend.sh with test
817817
backends/arm/test/test_arm_backend.sh "${ARM_TEST}"
818818
819+
test-arm-backend-public-api-backward-compatibility:
820+
name: test-arm-backend-public-api-backward-compatibility
821+
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
822+
permissions:
823+
id-token: write
824+
contents: read
825+
with:
826+
runner: linux.2xlarge.memory
827+
docker-image: ci-image:executorch-ubuntu-22.04-arm-sdk
828+
submodules: 'recursive'
829+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
830+
timeout: 120
831+
script: |
832+
# The generic Linux job chooses to use the base env, not the one set up by the image
833+
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
834+
conda activate "${CONDA_ENV}"
835+
836+
source .ci/scripts/utils.sh
837+
install_executorch "--use-pt-pinned-commit"
838+
839+
.ci/scripts/setup-arm-baremetal-tools.sh --enable-mlsdk-deps --install-mlsdk-deps-with-pip
840+
source examples/arm/arm-scratch/setup_path.sh
841+
842+
backends/arm/scripts/public_api_manifest/validate_all_public_api_manifests.sh
843+
844+
python backends/arm/test/public_api_bc/run_public_api_bc_scenarios.py
845+
819846
test-llama-runner-qnn-linux:
820847
name: test-llama-runner-qnn-linux
821848
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main

backends/arm/TARGETS

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -119,6 +119,15 @@ runtime.python_library(
119119
"//executorch/exir:lib",
120120
],
121121
)
122+
runtime.python_library(
123+
name = "public_api",
124+
srcs = ["__init__.py"],
125+
deps = [
126+
":ethosu",
127+
":vgf",
128+
"//executorch/backends/arm/quantizer:lib",
129+
],
130+
)
122131

123132
runtime.python_library(
124133
name = "process_node",

backends/arm/ao_ext/ops/mxfp_conv2d_op.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -206,11 +206,12 @@ def __init__(
206206
padding: tuple[int, int],
207207
dilation: tuple[int, int],
208208
groups: int,
209-
config: MXFPOpConfig,
209+
weight_dtype: MXFPDType,
210+
block_size: int,
210211
) -> None:
211212
super().__init__()
212-
self.config = config
213-
self.weight_dtype = mxfp_dtype_to_str(config.weight_dtype)
213+
self.weight_dtype = mxfp_dtype_to_str(weight_dtype)
214+
self.block_size = block_size
214215

215216
self.register_buffer("weight_qdata", weight_qdata, persistent=True)
216217
self.register_buffer("weight_scale", weight_scale, persistent=True)
@@ -241,7 +242,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
241242
list(self.padding),
242243
list(self.dilation),
243244
self.groups,
244-
self.config.block_size,
245+
self.block_size,
245246
self.weight_dtype,
246247
)
247248

@@ -283,5 +284,6 @@ def transform_conv2d_to_mxfp(
283284
padding,
284285
dilation,
285286
module.groups,
286-
config,
287+
config.weight_dtype,
288+
config.block_size,
287289
)

backends/arm/ao_ext/ops/mxfp_linear_op.py

Lines changed: 28 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,12 @@
3333
)
3434

3535

36+
_SUPPORTED_OUTPUT_DTYPES: set[torch.dtype] = {
37+
torch.float32,
38+
torch.bfloat16,
39+
}
40+
41+
3642
def _get_mx_elem_dtype(
3743
weight_qdata: torch.Tensor,
3844
weight_payload_dtype: str = "",
@@ -137,11 +143,14 @@ def __init__(
137143
weight_qdata: torch.Tensor,
138144
weight_scale: torch.Tensor,
139145
bias: torch.Tensor | None,
140-
config: MXFPOpConfig,
146+
weight_dtype: MXFPDType,
147+
block_size: int,
148+
output_dtype: torch.dtype = torch.float32,
141149
) -> None:
142150
super().__init__()
143-
self.config = config
144-
self.weight_dtype = mxfp_dtype_to_str(config.weight_dtype)
151+
self.weight_dtype = mxfp_dtype_to_str(weight_dtype)
152+
self.block_size = block_size
153+
self.output_dtype = output_dtype
145154

146155
self.register_buffer("weight_qdata", weight_qdata, persistent=True)
147156
self.register_buffer("weight_scale", weight_scale, persistent=True)
@@ -158,14 +167,17 @@ def __init__(
158167
)
159168

160169
def forward(self, x: torch.Tensor) -> torch.Tensor:
161-
return torch.ops.tosa_mxfp.linear.default(
170+
output = torch.ops.tosa_mxfp.linear.default(
162171
x,
163172
self.weight_qdata,
164173
self.weight_scale,
165174
self.bias,
166-
self.config.block_size,
175+
self.block_size,
167176
self.weight_dtype,
168177
)
178+
if self.output_dtype != torch.float32:
179+
output = output.to(self.output_dtype)
180+
return output
169181

170182

171183
def transform_linear_to_mxfp(
@@ -195,4 +207,14 @@ def transform_linear_to_mxfp(
195207
weight_scale = weight_scale.unsqueeze(0)
196208

197209
bias = module.bias.detach().to(torch.float32) if module.bias is not None else None
198-
return MXFPLinearOp(weight_qdata, weight_scale, bias, config)
210+
output_dtype = weight.dtype
211+
if output_dtype not in _SUPPORTED_OUTPUT_DTYPES:
212+
raise ValueError(f"Unsupported output_dtype: {output_dtype}")
213+
return MXFPLinearOp(
214+
weight_qdata,
215+
weight_scale,
216+
bias,
217+
config.weight_dtype,
218+
config.block_size,
219+
output_dtype,
220+
)

backends/arm/operators/op_tosa_identity.py

Lines changed: 12 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -3,42 +3,28 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from typing import Any, List
7-
8-
import torch
96
import tosa_serializer as ts
107

11-
from executorch.backends.arm.operators.node_visitor import (
12-
NodeVisitor,
13-
register_node_visitor,
14-
)
15-
from executorch.backends.arm.operators.operator_validation_utils import (
16-
validate_num_inputs,
17-
validate_same_dtype,
18-
validate_valid_dtype,
8+
from executorch.backends.arm.operators.node_visitor import register_node_visitor
9+
from executorch.backends.arm.operators.simple_node_visitor import (
10+
SimpleNodeVisitor,
11+
SimpleNodeVisitorConfig,
1912
)
20-
from executorch.backends.arm.tosa.mapping import TosaArg
2113

2214

2315
@register_node_visitor
24-
class IdentityVisitor(NodeVisitor):
16+
class IdentityVisitor(SimpleNodeVisitor):
2517
"""Lower the TOSA IDENTITY op."""
2618

2719
target = "tosa.IDENTITY.default"
2820

29-
def define_node(
30-
self,
31-
node: torch.fx.Node,
32-
tosa_graph: Any,
33-
inputs: List[TosaArg],
34-
output: TosaArg,
35-
) -> None:
36-
validate_num_inputs(self.target, inputs, 1)
37-
validate_same_dtype(self.target, [inputs[0], output], ts)
38-
validate_valid_dtype(
39-
self.target,
40-
[inputs[0], output],
41-
[
21+
@classmethod
22+
def get_config(cls) -> SimpleNodeVisitorConfig:
23+
return SimpleNodeVisitorConfig(
24+
tosa_op=ts.Op.IDENTITY,
25+
attr_method="IdentityAttribute",
26+
num_inputs=1,
27+
input_dtypes=[
4228
ts.DType.BOOL,
4329
ts.DType.INT8,
4430
ts.DType.INT16,
@@ -49,16 +35,4 @@ def define_node(
4935
ts.DType.FP8E4M3,
5036
ts.DType.FP8E5M2,
5137
],
52-
self.tosa_spec,
53-
)
54-
55-
attr = ts.TosaSerializerAttribute()
56-
attr.IdentityAttribute()
57-
self._serialize_operator(
58-
node,
59-
tosa_graph,
60-
ts.Op.IDENTITY,
61-
[inputs[0].name],
62-
[output.name],
63-
attr,
6438
)

backends/arm/public_api_manifests/api_manifest_running.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55
#
66
# This file is generated by
7-
# backends/arm/scripts/generate_public_api_manifest.py
7+
# backends/arm/scripts/public_api_manifest/generate_public_api_manifest.py
88

99
[python]
1010

backends/arm/runtime/VGFSetup.cpp

Lines changed: 81 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -778,6 +778,73 @@ static bool find_memory_index(
778778
memory_type_out);
779779
}
780780

781+
// Maps the VkDeviceMemory behind every entry in IOs and stores the resulting
// host pointer in io.persistent_memory.
//
// Each distinct VkDeviceMemory handle is mapped once; the mapping is recorded
// in persistent_mapped_memories and its pointer is reused for every IO entry
// that refers to the same handle.
//
// Returns true when every IO entry has been given a pointer. Returns false if
// an IO entry has a null memory handle or if vkMapMemory fails; in both cases
// all mappings made so far are released via unmap_persistent_io_memory()
// before returning.
bool VgfRepr::map_persistent_io_memory() {
782+
// Start from a clean state: release any mappings left from an earlier call.
unmap_persistent_io_memory();
783+
784+
for (auto& io : IOs) {
785+
if (io.memory == VK_NULL_HANDLE) {
786+
ET_LOG(Error, "Cannot persistently map null Vulkan IO memory");
787+
unmap_persistent_io_memory();
788+
return false;
789+
}
790+
791+
void* persistent_memory = nullptr;
792+
793+
// IO resources may alias the same VkDeviceMemory. Vulkan memory must not be
794+
// mapped more than once at the same time, so map each unique memory once
795+
// and share the returned pointer across aliased IO entries.
796+
// Make sure that memory is HOST_VISIBLE and HOST_COHERENT.
// NOTE(review): this function does not check the memory property flags
// itself; presumably they are guaranteed where the memory is allocated --
// confirm.
797+
bool found_existing_mapping = false;
798+
auto mapped_memory_it = std::find_if(
799+
persistent_mapped_memories.begin(),
800+
persistent_mapped_memories.end(),
801+
[&](const auto& mapped_memory) {
802+
return mapped_memory.memory == io.memory;
803+
});
804+
805+
if (mapped_memory_it != persistent_mapped_memories.end()) {
806+
persistent_memory = mapped_memory_it->data;
807+
found_existing_mapping = true;
808+
}
809+
810+
if (!found_existing_mapping) {
// Map the whole allocation starting at offset 0.
// NOTE(review): aliased IO entries therefore all receive the same base
// pointer; if they are bound at different offsets within the allocation,
// the offset has to be applied by whoever uses the pointer -- confirm.
811+
VkResult result = vkMapMemory(
812+
vk_device, io.memory, 0, VK_WHOLE_SIZE, 0, &persistent_memory);
813+
if (result != VK_SUCCESS) {
814+
ET_LOG(
815+
Error,
816+
"Failed to persistently map Vulkan IO memory, error %d",
817+
result);
818+
unmap_persistent_io_memory();
819+
return false;
820+
}
821+
// Remember the mapping so later IO entries with the same handle reuse it
// and so unmap_persistent_io_memory() can release it.
822+
persistent_mapped_memories.push_back(PersistentMappedMemory{
823+
.memory = io.memory,
824+
.data = persistent_memory,
825+
});
826+
}
827+
828+
io.persistent_memory = persistent_memory;
829+
}
830+
831+
return true;
832+
}
833+
834+
// Releases every mapping recorded in persistent_mapped_memories and resets
// the persistent_memory pointer of every entry in IOs to nullptr.
//
// Calling it when nothing is mapped only clears the (already empty) list and
// resets the IO pointers. map_persistent_io_memory() calls it on entry and on
// its failure paths.
void VgfRepr::unmap_persistent_io_memory() {
835+
for (const auto& mapped_memory : persistent_mapped_memories) {
// Skip records that do not describe a live mapping.
836+
if (mapped_memory.memory != VK_NULL_HANDLE &&
837+
mapped_memory.data != nullptr) {
838+
vkUnmapMemory(vk_device, mapped_memory.memory);
839+
}
840+
}
841+
persistent_mapped_memories.clear();
842+
// The per-IO pointers referred to the mappings released above, so clear
// them as well.
843+
for (auto& io : IOs) {
844+
io.persistent_memory = nullptr;
845+
}
846+
}
847+
781848
VkResult allocate_memory(
782849
VkPhysicalDevice physical,
783850
VkDevice device,
@@ -1839,6 +1906,7 @@ bool VgfRepr::process_vgf(
18391906
VK_NULL_HANDLE,
18401907
tensor_memory,
18411908
{0, 0, 0},
1909+
nullptr,
18421910
owns_memory,
18431911
true,
18441912
is_in});
@@ -1931,6 +1999,7 @@ bool VgfRepr::process_vgf(
19311999
VK_NULL_HANDLE,
19322000
buffer_memory,
19332001
{0, 0, 0},
2002+
nullptr,
19342003
owns_memory,
19352004
true,
19362005
is_in});
@@ -2117,6 +2186,7 @@ bool VgfRepr::process_vgf(
21172186
image_memory,
21182187
staging_memory,
21192188
image_extent,
2189+
nullptr,
21202190
true,
21212191
owns_image_memory,
21222192
is_in});
@@ -3433,6 +3503,15 @@ bool VgfRepr::process_vgf(
34333503
vkEndCommandBuffer(vk_execute_cmd);
34343504
}
34353505

3506+
{
3507+
VGF_PROFILE_SCOPE(event_tracer, "VGF_INIT_MAP_IO_MEMORY");
3508+
3509+
if (!map_persistent_io_memory()) {
3510+
ET_LOG(Error, "Failed to persistently map VGF IO memory");
3511+
return false;
3512+
}
3513+
}
3514+
34363515
return true;
34373516
}
34383517

@@ -3493,6 +3572,8 @@ bool VgfRepr::execute_vgf(executorch::runtime::EventTracer* event_tracer) {
34933572
}
34943573

34953574
void VgfRepr::free_vgf() {
3575+
unmap_persistent_io_memory();
3576+
34963577
if (vk_timestamp_query_pool != VK_NULL_HANDLE) {
34973578
vkDestroyQueryPool(vk_device, vk_timestamp_query_pool, nullptr);
34983579
vk_timestamp_query_pool = VK_NULL_HANDLE;

0 commit comments

Comments
 (0)