diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h
index 3859895b6e..2030164572 100644
--- a/include/ggml/ggml.h
+++ b/include/ggml/ggml.h
@@ -474,6 +474,7 @@ extern "C" {
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
+        GGML_OP_UNFOLD_1D,
         GGML_OP_UPSCALE, // nearest interpolate
         GGML_OP_PAD,
         GGML_OP_ARANGE,
@@ -1708,6 +1709,14 @@ extern "C" {
                                float                 p0,
                                float                 p1);
+
+    GGML_API struct ggml_tensor * ggml_unfold_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   k,
+            int                   s);
+
+
     // nearest interpolate
     // multiplies ne0 and ne1 by scale factor
     // used in stable-diffusion
diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu
index d0a754ee11..595b012b47 100644
--- a/src/ggml-cuda.cu
+++ b/src/ggml-cuda.cu
@@ -29,6 +29,7 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
+#include "ggml-cuda/unfold1d.cuh"
 
 #include <algorithm>
 #include <array>
@@ -2288,6 +2289,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_PAD:
             ggml_cuda_op_pad(ctx, dst);
             break;
+        case GGML_OP_UNFOLD_1D:
+            ggml_cuda_op_unfold_1d(ctx, dst);
+            break;
         case GGML_OP_ARANGE:
             ggml_cuda_op_arange(ctx, dst);
             break;
@@ -2895,6 +2899,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_GROUP_NORM:
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
+        case GGML_OP_UNFOLD_1D:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
diff --git a/src/ggml-cuda/unfold1d.cu b/src/ggml-cuda/unfold1d.cu
new file mode 100644
index 0000000000..ffd20e6eda
--- /dev/null
+++ b/src/ggml-cuda/unfold1d.cu
@@ -0,0 +1,44 @@
+#include "unfold1d.cuh"
+
+static __global__ void unfold_1d_f32(const float * x, float * dst, const int s, const int ne0, const int ne1, const int ne2,
+                                     const int ne3, const int ne00, const int ne01, const int ne02, const int ne03) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0 * ne1 * ne2 * ne3) {
+        return;
+    }
+
+    const int i3 = nidx/(ne0 * ne1 * ne2);
+    const int i2 = (nidx - i3*ne0*ne1*ne2) / (ne0*ne1);
+    const int i1 = (nidx - i3*ne0*ne1*ne2 - i2*ne1*ne0) / ne0;
+    const int i0 = nidx - i3*ne0*ne1*ne2 - i2*ne1*ne0 - i1*ne0;
+
+    const int src_idx = i3*(ne00*ne01) + i2*(ne00) + i1*s + i0;
+
+    dst[nidx] = x[src_idx];
+}
+
+static void unfold_1d_f32_cuda(const float * x, float * dst, const int s,
+                               const int ne0, const int ne1, const int ne2, const int ne3,
+                               const int ne00, const int ne01, const int ne02, const int ne03, cudaStream_t stream) {
+    int num_blocks = ((ne0 * ne1 * ne2 * ne3) + CUDA_UNFOLD_1D_BLOCK_SIZE - 1) / CUDA_UNFOLD_1D_BLOCK_SIZE;
+
+    unfold_1d_f32<<<num_blocks, CUDA_UNFOLD_1D_BLOCK_SIZE, 0, stream>>>(x, dst, s, ne0, ne1, ne2, ne3, ne00, ne01, ne02, ne03);
+}
+
+void ggml_cuda_op_unfold_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[3] == 1); // only up to 3 dimensions for the input tensor
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    const int s = opts[1]; // op_params = { k, s }
+
+    unfold_1d_f32_cuda(src0_d, dst_d, s,
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], stream);
+}
diff --git a/src/ggml-cuda/unfold1d.cuh b/src/ggml-cuda/unfold1d.cuh
new file mode 100644
index 0000000000..1a7766de25
--- /dev/null
+++ b/src/ggml-cuda/unfold1d.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_UNFOLD_1D_BLOCK_SIZE 256
+
+void ggml_cuda_op_unfold_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/src/ggml.c b/src/ggml.c
index 5025ec23b3..4c37a29831 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -2696,7 +2696,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2752,6 +2752,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "pool_1d(x)",
     "pool_2d(x)",
+    "unfold_1d(x)",
     "upscale(x)",
     "pad(x)",
     "arange(start, stop, step)",
     "timestep_embedding(timesteps, dim, max_period)",
@@ -2784,7 +2785,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -6852,6 +6853,43 @@ struct ggml_tensor * ggml_upscale_ext(
     return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
 }
 
+
+// ggml_unfold_1d
+
+struct ggml_tensor * ggml_unfold_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   k,
+        int                   s) {
+
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    GGML_ASSERT(a->ne[3] == 1);           // only up to 3D inputs are allowed, since this operation adds a dimension
+
+    GGML_ASSERT((a->ne[0] - k) % s == 0); // windows of size k with stride s must cover dimension 0 exactly
+
+    // each of the (a->ne[0] - k)/s + 1 windows of size k becomes a row of the result
+    const int64_t ne[4] = { k, ((a->ne[0] - k) / s) + 1, a->ne[1], a->ne[2] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    int32_t params[] = { k, s };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op   = GGML_OP_UNFOLD_1D;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+
 // ggml_pad
 
 struct ggml_tensor * ggml_pad(
@@ -15445,6 +15483,51 @@ static void ggml_compute_forward_upscale(
     }
 }
 
+// ggml_compute_forward_unfold_1d
+
+static void ggml_compute_forward_unfold_1d(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float * dst_ptr  = (float *) dst->data;
+    float * src0_ptr = (float *) src0->data;
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    const int s = opts[1]; // op_params = { k, s }; k == ne0 of dst
+
+    // TODO: optimize
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                for (int64_t i3 = 0; i3 < ne3; ++i3) {
+                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+                    const int64_t src_idx = i3*(ne00*ne01)   + i2*(ne00)    + i1*s   + i0;
+
+                    dst_ptr[dst_idx] = src0_ptr[src_idx];
+                }
+            }
+        }
+    }
+}
+
 // ggml_compute_forward_pad
 
@@ -17453,6 +17536,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_upscale(params, tensor);
             } break;
+        case GGML_OP_UNFOLD_1D:
+            {
+                ggml_compute_forward_unfold_1d(params, tensor);
+            } break;
         case GGML_OP_PAD:
             {
                 ggml_compute_forward_pad(params, tensor);
@@ -18463,6 +18550,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_UNFOLD_1D:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_PAD:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -19206,6 +19297,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
             {
                 n_tasks = n_threads;
             } break;
+        case GGML_OP_UNFOLD_1D:
+            {
+                n_tasks = 1;
+            } break;
         case GGML_OP_PAD:
            {
                n_tasks = n_threads;
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 0759596415..ffe6be8882 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -403,6 +403,14 @@ target_link_libraries(${TEST_TARGET} PRIVATE ggml)
 add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
 set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
 
+#
+# test-unfold-1d
+
+set(TEST_TARGET test-unfold-1d)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
 
 #
 # test-mul-mat
diff --git a/tests/test-unfold-1d.cpp b/tests/test-unfold-1d.cpp
new file mode 100644
index 0000000000..e447194fd1
--- /dev/null
+++ b/tests/test-unfold-1d.cpp
@@ -0,0 +1,292 @@
+#include "ggml.h"
+#include "ggml/ggml-alloc.h"
+#include "ggml/ggml-backend.h"
+
+//#define GGML_USE_CUBLAS
+
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <string>
+#include <vector>
+
+static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
+    (void) level;
+    (void) user_data;
+    fputs(text, stderr);
+    fflush(stderr);
+}
+
+struct test_model {
+    struct ggml_tensor * a_0;
+    struct ggml_tensor * a_1;
+
+    ggml_backend_t backend = NULL;
+    ggml_backend_buffer_t buffer;
+    struct ggml_context * ctx;
+};
+
+void load_model(test_model & model, bool use_gpu = false) {
+    // input values 0, 1, 2, ... are shared by both test tensors
+    float data[1024];
+    for (int i = 0; i < 1024; ++i) {
+        data[i] = (float)i;
+    }
+
+    size_t buffer_size = 0;
+    {
+        buffer_size += 2 * 6 * ggml_type_size(GGML_TYPE_F32);     // tensor a_0
+        buffer_size += 2 * 2 * 4 * ggml_type_size(GGML_TYPE_F32); // tensor a_1
+        buffer_size += 1024; // overhead
+    }
+
+    printf("%s: ggml tensor size    = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
+    printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f));
+
+    int num_tensors = 2;
+    struct ggml_init_params params {
+            /*.mem_size   =*/ ggml_tensor_overhead() * num_tensors,
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+    };
+
+    // initialize the backend
+#ifdef GGML_USE_CUBLAS
+    if (use_gpu) {
+        fprintf(stderr, "%s: using CUDA backend\n", __func__);
+        model.backend = ggml_backend_cuda_init(0);
+        if (!model.backend) {
+            fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+        }
+    }
+#endif
+
+#ifdef GGML_USE_METAL
+    if (use_gpu) {
+        fprintf(stderr, "%s: using Metal backend\n", __func__);
+        ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr);
+        model.backend = ggml_backend_metal_init();
+        if (!model.backend) {
+            fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
+        }
+    }
+#endif
+
+    if (!model.backend) {
+        // fallback to CPU backend
+        model.backend = ggml_backend_cpu_init();
+    }
+
+    model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size);
+
+    // create context
+    model.ctx = ggml_init(params);
+
+    // create tensors
+    model.a_0 = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, 6, 2);
+    model.a_1 = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 4, 2, 2);
+
+    // create an allocator
+    ggml_tallocr alloc = ggml_tallocr_new(model.buffer);
+
+    // alloc memory
+    ggml_tallocr_alloc(&alloc, model.a_0);
+
+    // load data to buffer
+    if (ggml_backend_is_cpu(model.backend)) {
+        memcpy(model.a_0->data, data, ggml_nbytes(model.a_0));
+    } else {
+        ggml_backend_tensor_set(model.a_0, data, 0, ggml_nbytes(model.a_0));
+    }
+
+    // alloc memory
+    ggml_tallocr_alloc(&alloc, model.a_1);
+
+    // load data to buffer
+    if (ggml_backend_is_cpu(model.backend)) {
+        memcpy(model.a_1->data, data, ggml_nbytes(model.a_1));
+    } else {
+        ggml_backend_tensor_set(model.a_1, data, 0, ggml_nbytes(model.a_1));
+    }
+}
+
+struct ggml_cgraph * build_graph(const test_model& model) {
+    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
+    static std::vector<uint8_t> buf(buf_size);
+
+    struct ggml_init_params params0 = {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf.data(),
+        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
+    };
+
+    // create a temporary context to build the graph
+    struct ggml_context * ctx0 = ggml_init(params0);
+
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    // unfold a_0 (6 x 2) with window 3 and stride 3 -> (3, 2, 2)
+    int k = 3;
+    int s = 3;
+
+    struct ggml_tensor* unfold_res_0 = ggml_unfold_1d(ctx0, model.a_0, k, s);
+    ggml_set_name(unfold_res_0, "unfold_res_0");
+    ggml_build_forward_expand(gf, unfold_res_0);
+
+    // unfold a_1 (4 x 2 x 2) with window 2 and stride 1 -> (2, 3, 2, 2)
+    k = 2;
+    s = 1;
+
+    struct ggml_tensor* unfold_res_1 = ggml_unfold_1d(ctx0, model.a_1, k, s);
+    ggml_set_name(unfold_res_1, "unfold_res_1");
+    ggml_build_forward_expand(gf, unfold_res_1);
+
+    // delete the temporary context used to build the graph
+    ggml_free(ctx0);
+    return gf;
+}
+
+struct ggml_cgraph* compute_graph(const test_model & model, ggml_gallocr_t allocr) {
+    struct ggml_cgraph * gf = build_graph(model);
+
+    // allocate tensors
+    ggml_gallocr_alloc_graph(allocr, gf);
+    int n_threads = 1;
+
+    if (ggml_backend_is_cpu(model.backend)) {
+        ggml_backend_cpu_set_n_threads(model.backend, n_threads);
+    }
+
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(model.backend)) {
+        ggml_backend_metal_set_n_cb(model.backend, n_threads);
+    }
+#endif
+
+    ggml_backend_graph_compute(model.backend, gf);
+
+    //ggml_graph_print(gf);
+
+    return gf;
+}
+
+int main(void)
+{
+    ggml_time_init();
+
+    test_model model;
+    load_model(model, true);
+
+    ggml_gallocr_t allocr = NULL;
+
+    {
+        allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));
+
+        // create the worst case graph for memory usage estimation
+        struct ggml_cgraph * gf = build_graph(model);
+
+        // compute the required memory
+        ggml_gallocr_reserve(allocr, gf);
+        size_t mem_size = ggml_gallocr_get_buffer_size(allocr, 0);
+        fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
+    }
+
+    struct ggml_cgraph * gf_res = compute_graph(model, allocr);
+
+    struct ggml_tensor * unfold_res_0 = NULL;
+
+    for (int i = 0; i < gf_res->n_nodes; i++) {
+        if (strcmp(ggml_get_name(gf_res->nodes[i]), "unfold_res_0") == 0) {
+            unfold_res_0 = gf_res->nodes[i];
+        }
+    }
+
+    float* unfold_data_0 = new float[ggml_nelements(unfold_res_0)];
+
+    ggml_backend_tensor_get(unfold_res_0, unfold_data_0, 0, ggml_nbytes(unfold_res_0));
+
+    const int n_unfold_test_0 = 2 * 2 * 3;
+
+    float expected_unfold_0[n_unfold_test_0] = {
+        0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0
+    };
+
+    struct ggml_tensor * unfold_res_1 = NULL;
+
+    for (int i = 0; i < gf_res->n_nodes; i++) {
+        if (strcmp(ggml_get_name(gf_res->nodes[i]), "unfold_res_1") == 0) {
+            unfold_res_1 = gf_res->nodes[i];
+        }
+    }
+
+    float* unfold_data_1 = new float[ggml_nelements(unfold_res_1)];
+
+    ggml_backend_tensor_get(unfold_res_1, unfold_data_1, 0, ggml_nbytes(unfold_res_1));
+
+    const int n_unfold_test_1 = 3 * 2 * 2 * 2;
+
+    float expected_unfold_1[n_unfold_test_1] = {
+        0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0,
+        8.0, 9.0, 9.0, 10.0, 10.0, 11.0, 12.0, 13.0, 13.0, 14.0, 14.0, 15.0
+    };
+
+    printf("\nPerforming test:\n");
+
+    bool passed = true;
+    for (int i = 0; i < n_unfold_test_0; i++) {
+        if (unfold_data_0[i] != expected_unfold_0[i]) {
+            std::cout << "index: "    << i << std::endl;
+            std::cout << "expected: " << expected_unfold_0[i] << std::endl;
+            std::cout << "actual: "   << unfold_data_0[i] << std::endl;
+            passed = false;
+            break;
+        }
+    }
+
+    printf("ggml_unfold_1d (%d): %s\n", (int) ggml_nelements(unfold_res_0), passed && (ggml_nelements(unfold_res_0) == n_unfold_test_0) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
+
+    passed = true;
+    for (int i = 0; i < n_unfold_test_1; i++) {
+        if (unfold_data_1[i] != expected_unfold_1[i]) {
+            std::cout << "index: "    << i << std::endl;
+            std::cout << "expected: " << expected_unfold_1[i] << std::endl;
+            std::cout << "actual: "   << unfold_data_1[i] << std::endl;
+            passed = false;
+            break;
+        }
+    }
+
+    printf("ggml_unfold_1d (%d): %s\n", (int) ggml_nelements(unfold_res_1), passed && (ggml_nelements(unfold_res_1) == n_unfold_test_1) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
+
+    delete[] unfold_data_0;
+    delete[] unfold_data_1;
+
+    ggml_free(model.ctx);
+
+    ggml_backend_buffer_free(model.buffer);
+    ggml_backend_free(model.backend);
+    ggml_gallocr_free(allocr);
+    return 0;
+}
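
Usage sketch (not part of the patch). For an input a with a->ne[3] == 1, ggml_unfold_1d(ctx, a, k, s) returns an F32 tensor of shape (k, (a->ne[0] - k)/s + 1, a->ne[1], a->ne[2]): dimension 1 indexes the length-k windows taken with stride s along dimension 0 of a, and (a->ne[0] - k) must be divisible by s. Below is a minimal, CPU-only sketch of how the new API could be called once the patch is applied; the context size and input values are arbitrary, and it relies only on the existing single-context ggml API rather than the backend/allocator setup used by the test above.

// unfold1d_example.c - illustrative only, assumes this patch is applied
#include "ggml.h"

#include <stdio.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024, // plenty for this toy graph
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,        // allocate tensor data inside the context
    };
    struct ggml_context * ctx = ggml_init(params);

    // 1-D input: [0, 1, 2, 3, 4, 5]
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 6);
    for (int i = 0; i < 6; ++i) {
        ((float *) a->data)[i] = (float) i;
    }

    // k = 3, s = 3 -> result shape (3, 2): windows [0,1,2] and [3,4,5]
    struct ggml_tensor * u = ggml_unfold_1d(ctx, a, /*k=*/3, /*s=*/3);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, u);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    // prints: 0 1 2 3 4 5
    for (int64_t i = 0; i < ggml_nelements(u); ++i) {
        printf("%g ", ((float *) u->data)[i]);
    }
    printf("\n");

    ggml_free(ctx);
    return 0;
}

Because the op stores only { k, s } in op_params and derives everything else from the source and destination shapes, the CPU and CUDA paths share the same indexing scheme; the test above covers both the non-overlapping case (k == s, on a_0) and the overlapping-window case (s < k, on a_1).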