diff --git a/README.md b/README.md index 399199a..cb686cd 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,7 @@ Usage: zimage-ncnn-vulkan -p prompt -o outfile [options]... -r random-seed random seed (default=rand) -m model-path z-image model path (default=z-image-turbo) -g gpu-id gpu device to use (-1=cpu, default=auto) + -b batch-size batched generation (default=1) ``` If you encounter a crash or error, try upgrading your GPU driver: diff --git a/src/main.cpp b/src/main.cpp index d27bec5..d6dac0b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -106,6 +106,7 @@ static void print_usage() fprintf(stdout, " -r random-seed random seed (default=rand)\n"); fprintf(stdout, " -m model-path z-image model path (default=z-image-turbo)\n"); fprintf(stdout, " -g gpu-id gpu device to use (-1=cpu, default=auto)\n"); + fprintf(stdout, " -b batch-size batched generation (default=1)\n"); } #if _WIN32 @@ -125,13 +126,14 @@ int main(int argc, char** argv) int seed = rand(); path_t model = PATHSTR("z-image-turbo"); int gpuid = 233; + int batch = 1; // parse cli args { #if _WIN32 setlocale(LC_ALL, ""); wchar_t opt; - while ((opt = getopt(argc, argv, L"p:n:o:s:l:r:m:g:h")) != (wchar_t)-1) + while ((opt = getopt(argc, argv, L"p:n:o:s:l:r:m:g:b:h")) != (wchar_t)-1) { switch (opt) { @@ -168,6 +170,9 @@ int main(int argc, char** argv) case L'g': gpuid = _wtoi(optarg); break; + case L'b': + batch = _wtoi(optarg); + break; case L'h': default: print_usage(); @@ -176,7 +181,7 @@ int main(int argc, char** argv) } #else // _WIN32 int opt; - while ((opt = getopt(argc, argv, "p:n:o:s:l:r:m:g:h")) != -1) + while ((opt = getopt(argc, argv, "p:n:o:s:l:r:m:g:b:h")) != -1) { switch (opt) { @@ -213,6 +218,9 @@ int main(int argc, char** argv) case 'g': gpuid = atoi(optarg); break; + case 'b': + batch = atoi(optarg); + break; case 'h': default: print_usage(); @@ -250,6 +258,12 @@ int main(int argc, char** argv) return -1; } + if (batch <= 0) + { + fprintf(stderr, "batch must be > 0 but got %d\n", batch); + return -1; + } + float guidance_scale; float scheduler_shift; if (model.find(PATHSTR("z-image-turbo")) != path_t::npos) @@ -296,6 +310,7 @@ int main(int argc, char** argv) fprintf(stderr, "steps = %d\n", steps); fprintf(stderr, "seed = %d\n", seed); fprintf(stderr, "gpu-id = %d\n", gpuid); + fprintf(stderr, "batch = %d\n", batch); const bool apply_cfg = guidance_scale > 0.f; @@ -332,6 +347,17 @@ int main(int argc, char** argv) NCNN_LOGE("vae_tile_size = %d x %d", vae_tile_width, vae_tile_height); } + if (batch > 1) + { + path_t filename = get_file_name_without_extension(outpath); + path_t ext = get_file_extension(outpath); +#if _WIN32 + fwprintf(stderr, L"batch generation enabled. output-path will be %ls-0.%ls %ls-1.%ls %ls-2.%ls ...\n", filename.c_str(), ext.c_str(), filename.c_str(), ext.c_str(), filename.c_str(), ext.c_str()); +#else + fprintf(stderr, "batch generation enabled. output-path will be %s-0.%s %s-1.%s %s-2.%s ...\n", filename.c_str(), ext.c_str(), filename.c_str(), ext.c_str(), filename.c_str(), ext.c_str()); +#endif + } + // tokenizer std::vector input_ids; std::vector neg_input_ids; @@ -362,13 +388,16 @@ int main(int argc, char** argv) } } - // prepare latent - ncnn::Mat latent; - ZImage::generate_latent(width, height, seed, latent); + // prepare latents + std::vector latents(batch); + for (int b = 0; b < batch; b++) + { + ZImage::generate_latent(width, height, seed + b, latents[b]); + } const int patch_size = 2; - const int num_patches_w = latent.w / patch_size; - const int num_patches_h = latent.h / patch_size; + const int num_patches_w = latents[0].w / patch_size; + const int num_patches_h = latents[0].h / patch_size; fprintf(stderr, "num_patches = %d x %d\n", num_patches_w, num_patches_h); @@ -429,10 +458,6 @@ int main(int argc, char** argv) } } - // patchify - ncnn::Mat x; - ZImage::patchify(latent, x); - // prepare timesteps std::vector sigmas; std::vector timesteps; @@ -466,140 +491,184 @@ int main(int argc, char** argv) all_final_layer.load(model, opt); - for (int z = 0; z < steps; z++) + for (int b = 0; b < batch; b++) { - ncnn::Mat t_embed = t_embeds.row_range(z, 1).clone(); + // patchify + ncnn::Mat x; + ZImage::patchify(latents[b], x); - // all_x_embedder - ncnn::Mat x_embed; - all_x_embedder.process(x, x_embed); + for (int z = 0; z < steps; z++) + { + ncnn::Mat t_embed = t_embeds.row_range(z, 1).clone(); - // noise_refiner - ncnn::Mat x_embed_refine; - noise_refiner.process(x_embed, x_cos, x_sin, t_embed, x_embed_refine); + // all_x_embedder + ncnn::Mat x_embed; + all_x_embedder.process(x, x_embed); - // concat x_embed_refine and cap_refine - ncnn::Mat unified_embed; - ZImage::concat_along_h(x_embed_refine, cap_refine, unified_embed); + // noise_refiner + ncnn::Mat x_embed_refine; + noise_refiner.process(x_embed, x_cos, x_sin, t_embed, x_embed_refine); - ncnn::Mat neg_unified_embed; - if (apply_cfg) - { - ZImage::concat_along_h(x_embed_refine, neg_cap_refine, neg_unified_embed); - } + // concat x_embed_refine and cap_refine + ncnn::Mat unified_embed; + ZImage::concat_along_h(x_embed_refine, cap_refine, unified_embed); - // unified - ncnn::Mat unified; - unified_refiner.process(unified_embed, unified_cos, unified_sin, t_embed, unified); + ncnn::Mat neg_unified_embed; + if (apply_cfg) + { + ZImage::concat_along_h(x_embed_refine, neg_cap_refine, neg_unified_embed); + } - ncnn::Mat neg_unified; - if (apply_cfg) - { - unified_refiner.process(neg_unified_embed, neg_unified_cos, neg_unified_sin, t_embed, neg_unified); - } + // unified + ncnn::Mat unified; + unified_refiner.process(unified_embed, unified_cos, unified_sin, t_embed, unified); - // all_final_layer - ncnn::Mat unified_final; - all_final_layer.process(unified, t_embed, unified_final); + ncnn::Mat neg_unified; + if (apply_cfg) + { + unified_refiner.process(neg_unified_embed, neg_unified_cos, neg_unified_sin, t_embed, neg_unified); + } - ncnn::Mat neg_unified_final; - if (apply_cfg) - { - all_final_layer.process(neg_unified, t_embed, neg_unified_final); - } + // all_final_layer + ncnn::Mat unified_final; + all_final_layer.process(unified, t_embed, unified_final); - if (apply_cfg) - { - // apply cfg - const int total = x.total(); - for (int i = 0; i < total; i++) + ncnn::Mat neg_unified_final; + if (apply_cfg) { - float pos = unified_final[i]; - float neg = neg_unified_final[i]; + all_final_layer.process(neg_unified, t_embed, neg_unified_final); + } - unified_final[i] = pos + guidance_scale * (pos - neg); + if (apply_cfg) + { + // apply cfg + const int total = x.total(); + for (int i = 0; i < total; i++) + { + float pos = unified_final[i]; + float neg = neg_unified_final[i]; + + unified_final[i] = pos + guidance_scale * (pos - neg); + } } - } - // euler scheduler step - { - const float dt = sigmas[z + 1] - sigmas[z]; + // euler scheduler step + { + const float dt = sigmas[z + 1] - sigmas[z]; - const int total = x.total(); - for (int i = 0; i < total; i++) + const int total = x.total(); + for (int i = 0; i < total; i++) + { + x[i] = x[i] - dt * unified_final[i]; + } + } + + if (batch > 1) + { + fprintf(stderr, "step %d/%d of image %d/%d done\n", z + 1, steps, b + 1, batch); + } + else { - x[i] = x[i] - dt * unified_final[i]; + fprintf(stderr, "step %d/%d done\n", z + 1, steps); } } - fprintf(stderr, "step %d done\n", z); + // unpatchify + ZImage::unpatchify(x, latents[b]); } } - // unpatchify - ZImage::unpatchify(x, latent); - - // vae decode - ncnn::Mat outimage; + // vae decode and save image { - const float vae_scaling_factor = 0.3611f; - const float vae_shift_factor = 0.1159f; - - for (int i = 0; i < latent.total(); i++) - { - latent[i] = latent[i] / vae_scaling_factor + vae_shift_factor; - } - const bool use_vae_tiled = vae_tile_width < width || vae_tile_height < height; ZImage::VAE vae; vae.load(model, use_vae_tiled, opt); - if (use_vae_tiled) - { - vae.process_tiled(latent, vae_tile_width, vae_tile_height, outimage); - } - else + for (int b = 0; b < batch; b++) { - vae.process(latent, outimage); - } - } + // vae decode + ncnn::Mat outimage; + { + const float vae_scaling_factor = 0.3611f; + const float vae_shift_factor = 0.1159f; - // save image - { - int success = 0; + for (int i = 0; i < latents[b].total(); i++) + { + latents[b][i] = latents[b][i] / vae_scaling_factor + vae_shift_factor; + } - path_t ext = get_file_extension(outpath); + if (use_vae_tiled) + { + vae.process_tiled(latents[b], vae_tile_width, vae_tile_height, outimage); + } + else + { + vae.process(latents[b], outimage); + } + } + + if (batch > 1) + { + fprintf(stderr, "vae of image %d/%d done\n", b + 1, batch); + } + else + { + fprintf(stderr, "vae done\n"); + } + + // save image + { + int success = 0; + + path_t ext = get_file_extension(outpath); + + path_t outpath_b = outpath; + if (batch > 1) + { + path_t filename = get_file_name_without_extension(outpath); - if (ext == PATHSTR("webp") || ext == PATHSTR("WEBP")) - { - success = webp_save(outpath.c_str(), outimage.w, outimage.h, outimage.elempack, (const unsigned char*)outimage.data); - } - else if (ext == PATHSTR("png") || ext == PATHSTR("PNG")) - { #if _WIN32 - success = wic_encode_image(outpath.c_str(), outimage.w, outimage.h, outimage.elempack, outimage.data); + wchar_t hnd[256]; + swprintf(hnd, 256, L"-%d.", b); #else - success = png_save(outpath.c_str(), outimage.w, outimage.h, outimage.elempack, (const unsigned char*)outimage.data); + char hnd[256]; + sprintf(hnd, "-%d.", b); #endif - } - else if (ext == PATHSTR("jpg") || ext == PATHSTR("JPG") || ext == PATHSTR("jpeg") || ext == PATHSTR("JPEG")) - { + outpath_b = filename + hnd + ext; + } + + if (ext == PATHSTR("webp") || ext == PATHSTR("WEBP")) + { + success = webp_save(outpath_b.c_str(), outimage.w, outimage.h, outimage.elempack, (const unsigned char*)outimage.data); + } + else if (ext == PATHSTR("png") || ext == PATHSTR("PNG")) + { #if _WIN32 - success = wic_encode_jpeg_image(outpath.c_str(), outimage.w, outimage.h, outimage.elempack, outimage.data); + success = wic_encode_image(outpath_b.c_str(), outimage.w, outimage.h, outimage.elempack, outimage.data); #else - success = jpeg_save(outpath.c_str(), outimage.w, outimage.h, outimage.elempack, (const unsigned char*)outimage.data); + success = png_save(outpath_b.c_str(), outimage.w, outimage.h, outimage.elempack, (const unsigned char*)outimage.data); #endif - } + } + else if (ext == PATHSTR("jpg") || ext == PATHSTR("JPG") || ext == PATHSTR("jpeg") || ext == PATHSTR("JPEG")) + { +#if _WIN32 + success = wic_encode_jpeg_image(outpath_b.c_str(), outimage.w, outimage.h, outimage.elempack, outimage.data); +#else + success = jpeg_save(outpath_b.c_str(), outimage.w, outimage.h, outimage.elempack, (const unsigned char*)outimage.data); +#endif + } - if (!success) - { + if (!success) + { #if _WIN32 - fwprintf(stderr, L"encode image %ls failed\n", outpath.c_str()); + fwprintf(stderr, L"encode image %ls failed\n", outpath_b.c_str()); #else - fprintf(stderr, "encode image %s failed\n", outpath.c_str()); + fprintf(stderr, "encode image %s failed\n", outpath_b.c_str()); #endif + } + } } }