
Commit bad2036

fabianmcg and joker-eph authored
[mlir][ptr] Extend ptr_add operation to support shaped operands (llvm#156374)
This patch extends `ptr_add` to work with shaped types with value semantics, both for the offsets and base. Concretely this patch makes the following changes:

- Supports scalar-to-scalar, scalar-to-shaped, shaped-to-scalar, and shaped-to-shaped combinations
- Adds InferTypeOpInterface for automatic result type deduction
- Adds tests for LLVM IR translation with vector operands

Example:

```mlir
func.func @ptr_add_tensor_2d(%ptrs: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4x8xindex>) -> tensor<4x8x!ptr.ptr<#ptr.generic_space>> {
  %res = ptr.ptr_add %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
  %res1 = ptr.ptr_add nuw %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
  return %res : tensor<4x8x!ptr.ptr<#ptr.generic_space>>
}
```

The motivation behind this patch is to lay the groundwork for enabling `triton`-styled loads and stores, and their variants.

---------

Co-authored-by: Mehdi Amini <[email protected]>
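As an illustration of the mixed scalar/shaped forms this enables, here is a short sketch (the SSA value names are illustrative; the type combinations mirror the tests added in mlir/test/Dialect/Ptr/ops.mlir):

```mlir
// Shaped base with a scalar offset: the same offset is added to every pointer.
%a = ptr.ptr_add %ptrs, %off : vector<4x!ptr.ptr<#ptr.generic_space>>, index

// Scalar base with shaped offsets: the base pointer is offset element-wise,
// yielding a vector of pointers.
%b = ptr.ptr_add %ptr, %offs : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
```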
1 parent d91a5c3 commit bad2036


8 files changed: +222 -43 lines changed


mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Ptr/IR/PtrDialect.h"
 #include "mlir/Dialect/Ptr/IR/PtrTypes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"

mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td

Lines changed: 66 additions & 37 deletions
@@ -13,6 +13,7 @@ include "mlir/Dialect/Ptr/IR/PtrDialect.td"
 include "mlir/Dialect/Ptr/IR/PtrAttrDefs.td"
 include "mlir/Dialect/Ptr/IR/PtrEnums.td"
 include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/OpAsmInterface.td"
@@ -34,8 +35,15 @@ class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
     /*descr=*/[{A shaped type with value semantics and rank.}],
     /*cppType=*/"::mlir::ShapedType">;

-// A shaped pointer type with value semantics and rank.
-class Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
+// A ptr-like type, either scalar or shaped type with value semantics.
+def Ptr_PtrLikeType :
+  AnyTypeOf<[Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>, Ptr_PtrType]>;
+
+// An int-like type, either scalar or shaped type with value semantics.
+def Ptr_IntLikeType : AnyTypeOf<[
+  Ptr_ShapedValueType<[AnySignlessIntegerOrIndex], [HasRankPred]>,
+  AnySignlessIntegerOrIndex
+]>;

 // A shaped value type of rank 1 of any element type.
 def Ptr_Any1DType :
@@ -167,41 +175,6 @@ def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [
   }];
 }

-//===----------------------------------------------------------------------===//
-// PtrAddOp
-//===----------------------------------------------------------------------===//
-
-def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
-  Pure, AllTypesMatch<["base", "result"]>, ViewLikeOpInterface
-]> {
-  let summary = "Pointer add operation";
-  let description = [{
-    The `ptr_add` operation adds an integer offset to a pointer to produce a new
-    pointer. The input and output pointer types are always the same.
-
-    Example:
-
-    ```mlir
-    %x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
-    %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
-    ```
-  }];
-
-  let arguments = (ins
-    Ptr_PtrType:$base,
-    AnySignlessIntegerOrIndex:$offset,
-    DefaultValuedProp<EnumProp<Ptr_PtrAddFlags>, "PtrAddFlags::none">:$flags);
-  let results = (outs Ptr_PtrType:$result);
-  let assemblyFormat = [{
-    ($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset)
-  }];
-  let hasFolder = 1;
-  let extraClassDeclaration = [{
-    /// `ViewLikeOp::getViewSource` method.
-    Value getViewSource() { return getBase(); }
-  }];
-}
-
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
@@ -361,6 +334,62 @@ def Ptr_MaskedStoreOp : Pointer_Op<"masked_store", [
   let hasVerifier = 1;
 }

+//===----------------------------------------------------------------------===//
+// PtrAddOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
+  Pure, ViewLikeOpInterface,
+  DeclareOpInterfaceMethods<InferTypeOpInterface>
+]> {
+  let summary = "Pointer add operation";
+  let description = [{
+    The `ptr_add` operation adds an int-like offset to one or more pointers to
+    produce one or more new pointers.
+
+    The operation supports both scalar and shaped types with value semantics:
+    - When both base and offset are scalar: produces a single new pointer
+    - When base is shaped and offset is scalar: adds the same offset to each
+      pointer in the base
+    - When base is scalar and offset is shaped: adds the single pointer to each
+      offset in the shaped value
+    - When both are shaped: performs element-wise addition (shapes must be
+      compatible)
+
+    Example:
+
+    ```mlir
+    // Scalar base and offset
+    %x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
+    %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32

+    // Shaped base with scalar offset
+    %ptrs_off = ptr.ptr_add %ptrs, %off : vector<4x!ptr.ptr<#ptr.generic_space>>, i32

+    // Scalar base with shaped offset
+    %x_offs = ptr.ptr_add %x, %offs : !ptr.ptr<#ptr.generic_space>, vector<4xi32>

+    // Both base and offset are shaped
+    %ptrs_offs = ptr.ptr_add %ptrs, %offs : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xi32>
+    ```
+  }];
+  let arguments = (ins
+    Ptr_PtrLikeType:$base,
+    Ptr_IntLikeType:$offset,
+    DefaultValuedProp<EnumProp<Ptr_PtrAddFlags>, "PtrAddFlags::none">:$flags);
+  let results = (outs Ptr_PtrLikeType:$result);
+  let assemblyFormat = [{
+    ($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset)
+  }];
+  let hasFolder = 1;
+  let extraClassDeclaration = [{
+    /// `ViewLikeOp::getViewSource` method.
+    Value getViewSource() { return getBase(); }
+
+    /// Returns the ptr type of the operation.
+    ptr::PtrType getPtrType();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//

mlir/lib/Dialect/Ptr/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ add_mlir_dialect_library(
   MLIRIR
   MLIRDataLayoutInterfaces
   MLIRMemorySlotInterfaces
+  MLIRInferTypeOpInterface
   MLIRViewLikeInterface
   MLIRPtrMemorySpaceInterfaces
 )

mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp

Lines changed: 39 additions & 0 deletions
@@ -346,6 +346,45 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }

+LogicalResult PtrAddOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  // Get the base pointer and offset types.
+  Type baseType = operands[0].getType();
+  Type offsetType = operands[1].getType();
+
+  auto offTy = dyn_cast<ShapedType>(offsetType);
+  if (!offTy) {
+    // If the offset isn't shaped, the result is always the base type.
+    inferredReturnTypes.push_back(baseType);
+    return success();
+  }
+  auto baseTy = dyn_cast<ShapedType>(baseType);
+  if (!baseTy) {
+    // Base isn't shaped but the offset is; use the ShapedType from the offset
+    // with the base pointer as element type.
+    inferredReturnTypes.push_back(offTy.clone(baseType));
+    return success();
+  }
+
+  // Both are shaped; their shapes must match.
+  if (offTy.getShape() != baseTy.getShape()) {
+    if (location)
+      mlir::emitError(*location) << "shapes of base and offset must match";
+    return failure();
+  }
+
+  // Make sure they are the same kind of shaped type.
+  if (baseType.getTypeID() != offsetType.getTypeID()) {
+    if (location)
+      mlir::emitError(*location) << "the shaped containers type must match";
+    return failure();
+  }
+  inferredReturnTypes.push_back(baseType);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ToPtrOp
 //===----------------------------------------------------------------------===//

mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir

Lines changed: 6 additions & 6 deletions
@@ -16,10 +16,10 @@
 // CHECK: llvm.return %[[VAL_8]] : !llvm.struct<(ptr, ptr, ptr, ptr)>
 // CHECK: }
 func.func @test_ptr_add(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: index) -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>) {
-  %0 = ptr.ptr_add %arg0, %arg1 : <#ptr.generic_space>, index
-  %1 = ptr.ptr_add nusw %arg0, %arg1 : <#ptr.generic_space>, index
-  %2 = ptr.ptr_add nuw %arg0, %arg1 : <#ptr.generic_space>, index
-  %3 = ptr.ptr_add inbounds %arg0, %arg1 : <#ptr.generic_space>, index
+  %0 = ptr.ptr_add %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
+  %1 = ptr.ptr_add nusw %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
+  %2 = ptr.ptr_add nuw %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
+  %3 = ptr.ptr_add inbounds %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
   return %0, %1, %2, %3 : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>
 }
@@ -263,7 +263,7 @@ func.func @test_comprehensive_dynamic(%arg0: memref<?x?xf32, strided<[?, ?], off
   %0 = ptr.to_ptr %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space> -> <#ptr.generic_space>
   %1 = ptr.get_metadata %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
   %2 = ptr.type_offset f32 : index
-  %3 = ptr.ptr_add inbounds %0, %2 : <#ptr.generic_space>, index
+  %3 = ptr.ptr_add inbounds %0, %2 : !ptr.ptr<#ptr.generic_space>, index
   %4 = ptr.from_ptr %3 metadata %1 : <#ptr.generic_space> -> memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
   return %4 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
 }
@@ -313,6 +313,6 @@ func.func @test_memref_ptradd_indexing(%arg0: memref<10x?x30xf32, #ptr.generic_s
   %0 = ptr.to_ptr %arg0 : memref<10x?x30xf32, #ptr.generic_space> -> <#ptr.generic_space>
   %1 = ptr.type_offset f32 : index
   %2 = arith.muli %1, %arg1 : index
-  %3 = ptr.ptr_add %0, %2 : <#ptr.generic_space>, index
+  %3 = ptr.ptr_add %0, %2 : !ptr.ptr<#ptr.generic_space>, index
   return %3 : !ptr.ptr<#ptr.generic_space>
 }

mlir/test/Dialect/Ptr/invalid.mlir

Lines changed: 16 additions & 0 deletions
@@ -54,3 +54,19 @@ func.func @llvm_store(%arg0: !ptr.ptr<#llvm.address_space<1>>, %arg1: memref<f32
   ptr.store %arg1, %arg0 : memref<f32>, !ptr.ptr<#llvm.address_space<1>>
   return
 }
+
+// -----
+
+func.func @ptr_add_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  // expected-error@+1 {{the shaped containers type must match}}
+  %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, vector<8xi64>
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+// -----
+
+func.func @ptr_add_shape_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  // expected-error@+1 {{shapes of base and offset must match}}
+  %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<4xi64>
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}

mlir/test/Dialect/Ptr/ops.mlir

Lines changed: 63 additions & 0 deletions
@@ -126,3 +126,66 @@ func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector
   ptr.masked_store %value, %ptr, %mask alignment = 4 : vector<4xf32>, !ptr.ptr<#llvm.address_space<3>>
   return %0 : vector<4xf32>
 }
+
+/// Test ptr_add with shaped operands (vectors)
+func.func @ptr_add_vector(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  %res0 = ptr.ptr_add none %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  %res1 = ptr.ptr_add nusw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  %res2 = ptr.ptr_add nuw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  %res3 = ptr.ptr_add inbounds %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with shaped operands (tensors)
+func.func @ptr_add_tensor(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<8xi64>
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with 2D tensors
+func.func @ptr_add_tensor_2d(%ptrs: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4x8xindex>) -> tensor<4x8x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
+  %res1 = ptr.ptr_add nuw %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
+  return %res : tensor<4x8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with scalar base and shaped offsets (vectors)
+func.func @ptr_add_scalar_base_vector_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  %res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  %res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  %res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  %res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with scalar base and shaped offsets (tensors)
+func.func @ptr_add_scalar_base_tensor_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  %res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  %res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  %res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  %res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with shaped base and scalar offset (vectors)
+func.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offset: index) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  %res0 = ptr.ptr_add none %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  %res1 = ptr.ptr_add nusw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  %res2 = ptr.ptr_add nuw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  %res3 = ptr.ptr_add inbounds %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with shaped base and scalar offset (tensors)
+func.func @ptr_add_tensor_base_scalar_offset(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offset: i64) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  %res0 = ptr.ptr_add none %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  %res1 = ptr.ptr_add nusw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  %res2 = ptr.ptr_add nuw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  %res3 = ptr.ptr_add inbounds %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}

mlir/test/Target/LLVMIR/ptr.mlir

Lines changed: 30 additions & 0 deletions
@@ -203,3 +203,33 @@ llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>
   ptr.masked_store %value, %ptr, %mask alignment = 8 : vector<4xf64>, !ptr.ptr<#llvm.address_space<3>>
   llvm.return
 }
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_vector
+// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], <4 x i32> %[[OFFSETS:.*]]) {
+// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, <4 x ptr> %[[PTRS]], <4 x i32> %[[OFFSETS]]
+// CHECK-NEXT: ret <4 x ptr> %[[RES]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_vector(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %offsets: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+  %res = ptr.ptr_add %ptrs, %offsets : vector<4x!ptr.ptr<#llvm.address_space<0>>>, vector<4xi32>
+  llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_scalar_base_vector_offsets
+// CHECK-SAME: (ptr %[[PTR:.*]], <4 x i32> %[[OFFSETS:.*]]) {
+// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, ptr %[[PTR]], <4 x i32> %[[OFFSETS]]
+// CHECK-NEXT: ret <4 x ptr> %[[RES]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_scalar_base_vector_offsets(%ptr: !ptr.ptr<#llvm.address_space<0>>, %offsets: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+  %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#llvm.address_space<0>>, vector<4xi32>
+  llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_vector_base_scalar_offset
+// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], i32 %[[OFFSET:.*]]) {
+// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, <4 x ptr> %[[PTRS]], i32 %[[OFFSET]]
+// CHECK-NEXT: ret <4 x ptr> %[[RES]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %offset: i32) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+  %res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#llvm.address_space<0>>>, i32
+  llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}
