@@ -406,6 +406,8 @@ struct vk_device_struct {
406
406
bool subgroup_ballot;
407
407
bool subgroup_clustered;
408
408
bool multi_add;
409
+ bool shader_int64;
410
+ bool buffer_device_address;
409
411
410
412
bool add_rms_fusion;
411
413
uint32_t partials_binding_alignment;
@@ -650,6 +652,7 @@ struct vk_buffer_struct {
650
652
vk::MemoryPropertyFlags memory_property_flags;
651
653
void * ptr;
652
654
size_t size = 0;
655
+ vk::DeviceAddress bda_addr {};
653
656
654
657
vk_device device;
655
658
@@ -982,6 +985,7 @@ struct vk_op_argsort_push_constants {
982
985
};
983
986
984
987
struct vk_op_im2col_push_constants {
988
+ uint64_t dst_addr;
985
989
uint32_t batch_offset; uint32_t offset_delta;
986
990
uint32_t IC;
987
991
uint32_t IW; uint32_t IH;
@@ -995,6 +999,7 @@ struct vk_op_im2col_push_constants {
995
999
};
996
1000
997
1001
struct vk_op_im2col_3d_push_constants {
1002
+ uint64_t dst_addr;
998
1003
uint32_t nb10;
999
1004
uint32_t nb11;
1000
1005
uint32_t nb12;
@@ -1946,10 +1951,17 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
1946
1951
return buf;
1947
1952
}
1948
1953
1954
+ vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
1955
+ vk::MemoryAllocateFlags mem_flags {};
1956
+ if (device->buffer_device_address) {
1957
+ usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
1958
+ mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
1959
+ }
1960
+
1949
1961
vk::BufferCreateInfo buffer_create_info{
1950
1962
vk::BufferCreateFlags(),
1951
1963
size,
1952
- vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst ,
1964
+ usage_flags ,
1953
1965
vk::SharingMode::eExclusive,
1954
1966
0,
1955
1967
nullptr,
@@ -1961,6 +1973,8 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
1961
1973
1962
1974
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
1963
1975
1976
+ const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
1977
+
1964
1978
for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
1965
1979
const auto & req_flags = *it;
1966
1980
@@ -1972,7 +1986,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
1972
1986
buf->memory_property_flags = req_flags;
1973
1987
1974
1988
try {
1975
- buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
1989
+ buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info });
1976
1990
break;
1977
1991
} catch (const vk::SystemError& e) {
1978
1992
// loop and retry
@@ -2000,6 +2014,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
2000
2014
buf->device = device;
2001
2015
buf->size = size;
2002
2016
2017
+ if (device->buffer_device_address) {
2018
+ const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
2019
+ buf->bda_addr = device->device.getBufferAddress(addressInfo);
2020
+ }
2021
+
2003
2022
#ifdef GGML_VULKAN_MEMORY_DEBUG
2004
2023
device->memory_logger->log_allocation(buf, size);
2005
2024
#endif
@@ -3447,14 +3466,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
3447
3466
3448
3467
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
3449
3468
3450
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3451
- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3452
- if (device->float_controls_rte_fp16) {
3453
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3454
- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3469
+ #define IM2COL(bda) \
3470
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3471
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3472
+ if (device->float_controls_rte_fp16) { \
3473
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3474
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3475
+ } else { \
3476
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3477
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3478
+ }
3479
+ if (device->shader_int64 && device->buffer_device_address) {
3480
+ IM2COL(_bda)
3455
3481
} else {
3456
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3457
- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3482
+ IM2COL()
3458
3483
}
3459
3484
3460
3485
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
@@ -3933,6 +3958,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
3933
3958
device->vendor_id != VK_VENDOR_ID_INTEL &&
3934
3959
getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
3935
3960
3961
+ device->shader_int64 = device_features2.features.shaderInt64;
3962
+ device->buffer_device_address = vk12_features.bufferDeviceAddress;
3963
+
3936
3964
if (device->subgroup_size_control) {
3937
3965
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
3938
3966
device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
@@ -9290,7 +9318,13 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
9290
9318
9291
9319
const uint32_t pelements = OW * KW * KH;
9292
9320
9321
+ const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
9322
+ const vk_buffer d_buf = d_buf_ctx->dev_buffer;
9323
+
9324
+ const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
9325
+
9293
9326
ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
9327
+ dst_addr,
9294
9328
batch_offset, offset_delta,
9295
9329
IC, IW, IH, OW, OH, KW, KH,
9296
9330
pelements,
@@ -9326,8 +9360,14 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx,
9326
9360
const int64_t OH = ne2;
9327
9361
const int64_t OW = ne1;
9328
9362
9363
+ const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
9364
+ const vk_buffer d_buf = d_buf_ctx->dev_buffer;
9365
+
9366
+ const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
9367
+
9329
9368
vk_op_im2col_3d_push_constants pc {};
9330
9369
9370
+ pc.dst_addr = dst_addr;
9331
9371
pc.nb10 = nb10 / ggml_type_size(src1->type);
9332
9372
pc.nb11 = nb11 / ggml_type_size(src1->type);
9333
9373
pc.nb12 = nb12 / ggml_type_size(src1->type);
0 commit comments