Skip to content

Can you give the visualization code for the uncertainty estimation for each pixel point? #1

@hejiaxiang1

Description

@hejiaxiang1

I tried to visualize the variance (`var`) part, but the output contains no useful information. My modified `/sd/dpmsolver_skipUQ.py` code is as follows:

    #########   start sample  ########## 
    # NOTE(review): this is the sampling section of a DPM-Solver-2 "skip-UQ"
    # script.  Names such as model, opt, device, ns, t_seq, uq_array,
    # custom_ld, conditioned_update, conditioned_exp_iteration,
    # conditioned_var_iteration, sample_from_gaussion and get_model_input_time
    # are defined earlier in the file; everything documented below is inferred
    # only from the code visible here.
    # Build the prompt conditioning once and repeat it across the batch.
    c = model.get_learned_conditioning(opt.prompt)
    c = torch.concat(opt.sample_batch_size * [c], dim=0)
    exp_dir = f'./dpm_solver_2_exp/skipUQ/{opt.prompt}_train{opt.train_la_data_size}_step{opt.timesteps}_S{opt.mc_size}/'
    os.makedirs(exp_dir, exist_ok=True)
    total_n_samples = opt.total_n_samples
    # Sampling proceeds in n_rounds batches of exactly sample_batch_size each.
    if total_n_samples % opt.sample_batch_size != 0:
        raise ValueError("Total samples for sampling must be divided exactly by opt.sample_batch_size, but got {} and {}".format(total_n_samples, opt.sample_batch_size))
    n_rounds = total_n_samples // opt.sample_batch_size
    # Per-image scalar uncertainty (summed final latent variance), one column
    # per sampling round.
    var_sum = torch.zeros((opt.sample_batch_size, n_rounds)).to(device)
    sample_x = []
    var_x = [] # added: collects the per-pixel latent variance map of each batch
    img_id = 1000000
    precision_scope = autocast if opt.precision=="autocast" else nullcontext
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                for loop in tqdm(
                    range(n_rounds), desc="Generating image samples for FID evaluation."
                ):
                    
                    # Draw the initial latent xT; the solver starts at index
                    # opt.timesteps//2 of t_seq and walks down to 0.
                    xT, timestep, mc_sample_size  = torch.randn([opt.sample_batch_size, opt.C, opt.H // opt.f, opt.W // opt.f], device=device), opt.timesteps//2, opt.mc_size
                    T = t_seq[timestep]
                    # uq_array[i] flags the steps on which uncertainty is
                    # propagated (mean/variance/covariance tracked) rather
                    # than a plain deterministic update.
                    if uq_array[timestep] == True:
                        xt_next = xT
                        # xT is a fixed draw, so its mean is xT and variance 0.
                        exp_xt_next, var_xt_next = xT, torch.zeros_like(xT).to(device)
                        # custom_ld(...) returns mean and variance of the
                        # predicted noise (Laplace-style uncertain net).
                        eps_mu_t_next, eps_var_t_next = custom_ld(xT, get_model_input_time(ns, T).expand(xT.shape[0]), c=c) 
                        cov_xt_next_epst_next = torch.zeros_like(xT).to(device)
                        # One update is run only to obtain the intermediate
                        # model output model_s1; the state result is discarded.
                        _, model_s1, _ = conditioned_update(ns, xt_next, T, t_seq[timestep-1], custom_ld, eps_mu_t_next, pre_wuq=True, r1=0.5, c=c)
                        list_eps_mu_t_next_i = torch.unsqueeze(model_s1, dim=0)
                    else:
                        xt_next = xT
                        exp_xt_next, var_xt_next = xT, torch.zeros_like(xT).to(device)
                        # No-UQ path: deterministic mean prediction only.
                        eps_mu_t_next = custom_ld.accurate_forward(xT, get_model_input_time(ns, T).expand(xT.shape[0]), c=c)
                    
                    ####### Start skip UQ sampling  ######
                    # NOTE(review): `timestep` is rebound here, shadowing the
                    # value assigned above -- intentional: the solver loop
                    # restarts at opt.timesteps//2 and counts down to 1.
                    for timestep in range(opt.timesteps//2, 0, -1):

                        # Shift last iteration's "next" quantities into the
                        # "current" slots for this step.
                        if uq_array[timestep] == True:
                            xt = xt_next
                            exp_xt, var_xt = exp_xt_next, var_xt_next
                            eps_mu_t, eps_var_t, cov_xt_epst = eps_mu_t_next, eps_var_t_next, cov_xt_next_epst_next
                            # Monte-Carlo estimate of E[eps] from the samples
                            # gathered at the end of the previous iteration.
                            mc_eps_exp_t = torch.mean(list_eps_mu_t_next_i, dim=0)
                        else: 
                            xt = xt_next
                            exp_xt, var_xt = exp_xt_next, var_xt_next
                            eps_mu_t = eps_mu_t_next
                        
                        # Current interval [s, t] in the time grid.
                        s, t = t_seq[timestep], t_seq[timestep-1]
                        if uq_array[timestep] == True:
                            # UQ step: draw a noise realisation, then update
                            # the state, its mean and its variance in lockstep.
                            eps_t= sample_from_gaussion(eps_mu_t, eps_var_t)
                            xt_next, _ , model_s1_var = conditioned_update(ns=ns, x=xt, s=s, t=t, custom_ld=custom_ld, model_s=eps_t, pre_wuq=uq_array[timestep], c=c, r1=0.5)
                            exp_xt_next = conditioned_exp_iteration(exp_xt, ns, s, t, pre_wuq=uq_array[timestep], mc_eps_exp_s1=mc_eps_exp_t)
                            var_xt_next = conditioned_var_iteration(var_xt, ns, s, t, pre_wuq=uq_array[timestep], cov_xt_epst= cov_xt_epst, var_epst=model_s1_var)
                            # decide whether to see xt_next as a random variable
                            if uq_array[timestep-1] == True:
                                # Precompute DPM-Solver-2 half-step constants
                                # for the *next* interval [s_next, t_next].
                                # NOTE(review): at timestep == 1 this reads
                                # t_seq[timestep-2] == t_seq[-1]; presumably
                                # uq_array[0] is always False so this branch
                                # is never reached there -- verify where
                                # uq_array is constructed.
                                list_xt_next_i, list_eps_mu_t_next_i=[], []
                                s_next = t_seq[timestep-1]
                                t_next = t_seq[timestep-2]
                                lambda_s_next, lambda_t_next = ns.marginal_lambda(s_next), ns.marginal_lambda(t_next)
                                h_next = lambda_t_next - lambda_s_next
                                # Midpoint (r1 = 0.5) of the interval in
                                # lambda (log-SNR) space.
                                lambda_s1_next = lambda_s_next + 0.5 * h_next
                                s1_next = ns.inverse_lambda(lambda_s1_next)
                                sigma_s1_next = ns.marginal_std(s1_next)
                                log_alpha_s_next, log_alpha_s1_next = ns.marginal_log_mean_coeff(s_next), ns.marginal_log_mean_coeff(s1_next)
                                phi_11_next = torch.expm1(0.5*h_next)

                                # MC loop: push mc_sample_size draws of xt_next
                                # through a half solver step to estimate
                                # E[eps] and Cov(x, eps) for the next step.
                                for _ in range(mc_sample_size):
                                    
                                    # Clamp: the variance iteration can go
                                    # slightly negative numerically.
                                    var_xt_next = torch.clamp(var_xt_next, min=0)
                                    xt_next_i = sample_from_gaussion(exp_xt_next, var_xt_next)
                                    list_xt_next_i.append(xt_next_i)
                                    model_t_i, model_t_i_var = custom_ld(xt_next_i, get_model_input_time(ns, s_next).expand(xt_next_i.shape[0]), c=c)
                                    # Half-step state at s1_next, with the
                                    # model's own variance injected into the
                                    # Gaussian draw.
                                    xu_next_i = sample_from_gaussion(torch.exp(log_alpha_s1_next - log_alpha_s_next) * xt_next_i-(sigma_s1_next * phi_11_next) * model_t_i, \
                                                                    torch.square(sigma_s1_next * phi_11_next) * model_t_i_var)
                                    model_u_i, _ = custom_ld(xu_next_i, get_model_input_time(ns, s1_next).expand(xt_next_i.shape[0]), c=c)
                                    list_eps_mu_t_next_i.append(model_u_i)

                                eps_mu_t_next, eps_var_t_next = custom_ld(xt_next, get_model_input_time(ns, s_next).expand(xt_next.shape[0]), c=c)
                                list_xt_next_i = torch.stack(list_xt_next_i, dim=0).to(device)
                                list_eps_mu_t_next_i = torch.stack(list_eps_mu_t_next_i, dim=0).to(device)
                                # Sample covariance: E[x*eps] - E[x]*E[eps].
                                cov_xt_next_epst_next = torch.mean(list_xt_next_i*list_eps_mu_t_next_i, dim=0)-exp_xt_next*torch.mean(list_eps_mu_t_next_i, dim=0)
                            else:
                                eps_mu_t_next = custom_ld.accurate_forward(xt_next, get_model_input_time(ns, t).expand(xt_next.shape[0]), c=c)

                        else:
                            # Deterministic step: mean update only; the
                            # variance is carried through by the iteration
                            # without a UQ contribution.
                            xt_next, model_s1 = conditioned_update(ns=ns, x=xt, s=s, t=t, custom_ld=custom_ld, model_s=eps_mu_t, pre_wuq=uq_array[timestep], c=c, r1=0.5)
                            exp_xt_next = conditioned_exp_iteration(exp_xt, ns, s, t, exp_s1= model_s1, pre_wuq=uq_array[timestep])
                            var_xt_next = conditioned_var_iteration(var_xt, ns, s, t, pre_wuq=uq_array[timestep])
                            # NOTE(review): the block below duplicates the MC
                            # block in the UQ branch verbatim -- in the real
                            # file it is a candidate for extraction into a
                            # shared helper.
                            if uq_array[timestep-1] == True:
                                list_xt_next_i, list_eps_mu_t_next_i=[], []
                                s_next = t_seq[timestep-1]
                                t_next = t_seq[timestep-2]
                                lambda_s_next, lambda_t_next = ns.marginal_lambda(s_next), ns.marginal_lambda(t_next)
                                h_next = lambda_t_next - lambda_s_next
                                lambda_s1_next = lambda_s_next + 0.5 * h_next
                                s1_next = ns.inverse_lambda(lambda_s1_next)
                                sigma_s1_next = ns.marginal_std(s1_next)
                                log_alpha_s_next, log_alpha_s1_next = ns.marginal_log_mean_coeff(s_next), ns.marginal_log_mean_coeff(s1_next)
                                phi_11_next = torch.expm1(0.5*h_next)

                                for _ in range(mc_sample_size):
                                    
                                    var_xt_next = torch.clamp(var_xt_next, min=0)
                                    xt_next_i = sample_from_gaussion(exp_xt_next, var_xt_next)
                                    list_xt_next_i.append(xt_next_i)
                                    model_t_i, model_t_i_var = custom_ld(xt_next_i, get_model_input_time(ns, s_next).expand(xt_next_i.shape[0]), c=c)
                                    xu_next_i = sample_from_gaussion(torch.exp(log_alpha_s1_next - log_alpha_s_next) * xt_next_i-(sigma_s1_next * phi_11_next) * model_t_i, \
                                                                    torch.square(sigma_s1_next * phi_11_next) * model_t_i_var)
                                    model_u_i, _ = custom_ld(xu_next_i, get_model_input_time(ns, s1_next).expand(xt_next_i.shape[0]), c=c)
                                    list_eps_mu_t_next_i.append(model_u_i)

                                eps_mu_t_next, eps_var_t_next = custom_ld(xt_next, get_model_input_time(ns, s_next).expand(xt_next.shape[0]), c=c)
                                list_xt_next_i = torch.stack(list_xt_next_i, dim=0).to(device)
                                list_eps_mu_t_next_i = torch.stack(list_eps_mu_t_next_i, dim=0).to(device)
                                cov_xt_next_epst_next = torch.mean(list_xt_next_i*list_eps_mu_t_next_i, dim=0)-exp_xt_next*torch.mean(list_eps_mu_t_next_i, dim=0)
                            else:
                                eps_mu_t_next = custom_ld.accurate_forward(xt_next, get_model_input_time(ns, t).expand(xt_next.shape[0]), c=c)

                    ####### Save variance and sample image  ######         
                    # Scalar uncertainty per image: total final latent variance.
                    var_sum[:, loop] = var_xt_next.sum(dim=(1,2,3))
                    x_samples = model.decode_first_stage(xt_next) # latent -> image space
                    # NOTE(review): decoding the variance map through the
                    # (nonlinear) first-stage decoder is not meaningful, so
                    # this stays disabled; the raw latent variance is kept
                    # below instead.
                    # var_xt_next = model.decode_first_stage(var_xt_next)
                    x = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                    # Per-image PNG dump disabled in favour of the sorted grid.
                    # os.makedirs(os.path.join(exp_dir, 'sam/'), exist_ok=True)
                    # for i in range(x.shape[0]):
                    #     path = os.path.join(exp_dir, 'sam/', f"{img_id}.png")
                    #     tvu.save_image(x.cpu()[i].float(), path)
                    #     img_id += 1
                    sample_x.append(x)
                    var_x.append(var_xt_next) # added: latent-space variance map of this batch

                sample_x = torch.concat(sample_x, dim=0)
                var_x = torch.concat(var_x, dim=0)# added
                # Flatten var_sum column by column into one vector aligned
                # with the concatenation order of sample_x.
                var = []
                for j in range(n_rounds):
                    var.append(var_sum[:, j])
                var = torch.concat(var, dim=0)
                # Sort all images by total uncertainty, highest first.
                sorted_var, sorted_indices = torch.sort(var, descending=True)
                reordered_sample_x = torch.index_select(sample_x, dim=0, index=sorted_indices.int())
                grid_sample_x = tvu.make_grid(reordered_sample_x, nrow=8, padding=2)
                tvu.save_image(grid_sample_x.cpu().float(), os.path.join(exp_dir, "sorted_sample.png"))

                print(f'Sampling {total_n_samples} images in {exp_dir}')
                torch.save(var_sum.cpu(), os.path.join(exp_dir, 'var_sum.pt'))

                # Collapse the latent channel axis to a single heat-map
                # channel; make_grid(normalize=True) rescales values for
                # display.
                # NOTE(review): var_x lives in latent space (H//f x W//f), so
                # the saved map is much smaller than the decoded image; to
                # overlay it on sorted_sample.png it would need upsampling
                # (e.g. F.interpolate) -- likely why it "looks uninformative".
                var_x = var_x.mean(dim=1, keepdim=True) # added
                reordered_var_x = torch.index_select(var_x, dim=0, index=sorted_indices.int()) # added
                grid_var_x = tvu.make_grid(reordered_var_x, nrow=12, padding=1, normalize=True) # added
                tvu.save_image(grid_var_x.cpu().float(), os.path.join(exp_dir, "sorted_var.png")) # added

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions