-
Notifications
You must be signed in to change notification settings - Fork 6
Open
Description
I tried to visualize the variance (`var`) output, but the resulting image contains no useful information. My modified `/sd/dpmsolver_skipUQ.py` code is as follows:
######### start sample ##########
# Text conditioning for the prompt, tiled once per sample in a batch.
c = model.get_learned_conditioning(opt.prompt)
c = torch.concat([c] * opt.sample_batch_size, dim=0)

# Output directory, keyed by the prompt and the train_la_data_size /
# timesteps / mc_size options.
exp_dir = f'./dpm_solver_2_exp/skipUQ/{opt.prompt}_train{opt.train_la_data_size}_step{opt.timesteps}_S{opt.mc_size}/'
os.makedirs(exp_dir, exist_ok=True)

# The total sample count must split into whole batches.
total_n_samples = opt.total_n_samples
n_rounds, leftover = divmod(total_n_samples, opt.sample_batch_size)
if leftover != 0:
    raise ValueError("Total samples for sampling must be divided exactly by opt.sample_batch_size, but got {} and {}".format(total_n_samples, opt.sample_batch_size))

# Per-sample total variance: one row per batch slot, one column per round.
var_sum = torch.zeros(opt.sample_batch_size, n_rounds).to(device)
sample_x = []  # collects one decoded image batch per round
var_x = []  # collects the matching per-round variance batches
img_id = 1_000_000

# Mixed precision only when requested; otherwise a no-op context manager.
if opt.precision == "autocast":
    precision_scope = autocast
else:
    precision_scope = nullcontext
# Main sampling loop: runs without autograd, inside the selected precision
# scope and the model's ema_scope(). Each round samples one batch, tracks the
# mean (exp_*) and variance (var_*) of the latent trajectory alongside the
# sampled trajectory (xt_*), then decodes the final latent and records results.
with torch.no_grad():
    with precision_scope("cuda"):
        with model.ema_scope():
            # One iteration per batch of opt.sample_batch_size samples.
            for loop in tqdm(
                range(n_rounds), desc="Generating image samples for FID evaluation."
            ):
                # Fresh latent noise; the solver starts at index opt.timesteps//2 of
                # t_seq and uses mc_sample_size Monte-Carlo draws on UQ steps.
                xT, timestep, mc_sample_size = torch.randn([opt.sample_batch_size, opt.C, opt.H // opt.f, opt.W // opt.f], device=device), opt.timesteps//2, opt.mc_size
                T = t_seq[timestep]
                if uq_array[timestep] == True:
                    # Starting state is known exactly: mean = xT, variance = 0.
                    xt_next = xT
                    exp_xt_next, var_xt_next = xT, torch.zeros_like(xT).to(device)
                    # custom_ld returns two tensors that are later used as the mean
                    # and variance of eps (see sample_from_gaussion below).
                    eps_mu_t_next, eps_var_t_next = custom_ld(xT, get_model_input_time(ns, T).expand(xT.shape[0]), c=c)
                    # x_T is deterministic, so its covariance with eps starts at zero.
                    cov_xt_next_epst_next = torch.zeros_like(xT).to(device)
                    # Only the second output (model_s1) is kept; it is stored as a
                    # one-element stack whose mean becomes mc_eps_exp_t next step.
                    _, model_s1, _ = conditioned_update(ns, xt_next, T, t_seq[timestep-1], custom_ld, eps_mu_t_next, pre_wuq=True, r1=0.5, c=c)
                    list_eps_mu_t_next_i = torch.unsqueeze(model_s1, dim=0)
                else:
                    # Non-UQ start: only a point estimate of eps is needed.
                    xt_next = xT
                    exp_xt_next, var_xt_next = xT, torch.zeros_like(xT).to(device)
                    eps_mu_t_next = custom_ld.accurate_forward(xT, get_model_input_time(ns, T).expand(xT.shape[0]), c=c)
                ####### Start skip UQ sampling ######
                # Walk t_seq from index opt.timesteps//2 down to 1; each step moves
                # from s = t_seq[timestep] to t = t_seq[timestep-1].
                for timestep in range(opt.timesteps//2, 0, -1):
                    # Promote the previous step's "next" quantities to the current step.
                    if uq_array[timestep] == True:
                        xt = xt_next
                        exp_xt, var_xt = exp_xt_next, var_xt_next
                        eps_mu_t, eps_var_t, cov_xt_epst = eps_mu_t_next, eps_var_t_next, cov_xt_next_epst_next
                        # Monte-Carlo estimate of E[eps] at the intermediate point.
                        mc_eps_exp_t = torch.mean(list_eps_mu_t_next_i, dim=0)
                    else:
                        xt = xt_next
                        exp_xt, var_xt = exp_xt_next, var_xt_next
                        eps_mu_t = eps_mu_t_next
                    s, t = t_seq[timestep], t_seq[timestep-1]
                    if uq_array[timestep] == True:
                        # UQ step: sample eps from its predicted mean/variance (via
                        # sample_from_gaussion, presumably a Gaussian draw), advance
                        # the sample, and propagate mean and variance of x.
                        eps_t= sample_from_gaussion(eps_mu_t, eps_var_t)
                        xt_next, _ , model_s1_var = conditioned_update(ns=ns, x=xt, s=s, t=t, custom_ld=custom_ld, model_s=eps_t, pre_wuq=uq_array[timestep], c=c, r1=0.5)
                        exp_xt_next = conditioned_exp_iteration(exp_xt, ns, s, t, pre_wuq=uq_array[timestep], mc_eps_exp_s1=mc_eps_exp_t)
                        var_xt_next = conditioned_var_iteration(var_xt, ns, s, t, pre_wuq=uq_array[timestep], cov_xt_epst= cov_xt_epst, var_epst=model_s1_var)
                        # decide whether to see xt_next as a random variable
                        # Lookahead: if the NEXT step is a UQ step, pre-compute the
                        # Monte-Carlo statistics it needs (E[eps] at the midpoint and
                        # Cov(x, eps)).
                        if uq_array[timestep-1] == True:
                            list_xt_next_i, list_eps_mu_t_next_i=[], []
                            s_next = t_seq[timestep-1]
                            # NOTE(review): when timestep == 1 this is t_seq[-1] (Python
                            # wrap-around). The lookahead results are not consumed after
                            # the last step; confirm uq_array[0] is False or that this
                            # is intended.
                            t_next = t_seq[timestep-2]
                            # Midpoint s1_next halfway between s_next and t_next in
                            # lambda (ns.marginal_lambda) space.
                            lambda_s_next, lambda_t_next = ns.marginal_lambda(s_next), ns.marginal_lambda(t_next)
                            h_next = lambda_t_next - lambda_s_next
                            lambda_s1_next = lambda_s_next + 0.5 * h_next
                            s1_next = ns.inverse_lambda(lambda_s1_next)
                            sigma_s1_next = ns.marginal_std(s1_next)
                            log_alpha_s_next, log_alpha_s1_next = ns.marginal_log_mean_coeff(s_next), ns.marginal_log_mean_coeff(s1_next)
                            phi_11_next = torch.expm1(0.5*h_next)
                            for _ in range(mc_sample_size):
                                # NOTE(review): this is the only place var_xt_next is
                                # clamped to >= 0. If no lookahead runs after the final
                                # step, the variance saved below may contain negatives.
                                var_xt_next = torch.clamp(var_xt_next, min=0)
                                xt_next_i = sample_from_gaussion(exp_xt_next, var_xt_next)
                                list_xt_next_i.append(xt_next_i)
                                model_t_i, model_t_i_var = custom_ld(xt_next_i, get_model_input_time(ns, s_next).expand(xt_next_i.shape[0]), c=c)
                                # Sample the midpoint state: mean and variance follow the
                                # linear update exp(dlog_alpha)*x - sigma_s1*phi_11*eps
                                # (appears to mirror DPM-Solver-2's first sub-step).
                                xu_next_i = sample_from_gaussion(torch.exp(log_alpha_s1_next - log_alpha_s_next) * xt_next_i-(sigma_s1_next * phi_11_next) * model_t_i, \
                                torch.square(sigma_s1_next * phi_11_next) * model_t_i_var)
                                model_u_i, _ = custom_ld(xu_next_i, get_model_input_time(ns, s1_next).expand(xt_next_i.shape[0]), c=c)
                                list_eps_mu_t_next_i.append(model_u_i)
                            eps_mu_t_next, eps_var_t_next = custom_ld(xt_next, get_model_input_time(ns, s_next).expand(xt_next.shape[0]), c=c)
                            list_xt_next_i = torch.stack(list_xt_next_i, dim=0).to(device)
                            list_eps_mu_t_next_i = torch.stack(list_eps_mu_t_next_i, dim=0).to(device)
                            # Cov(x, eps) = E[x * eps] - E[x] * E[eps]; E[x] uses the
                            # propagated mean exp_xt_next, not the MC sample mean.
                            cov_xt_next_epst_next = torch.mean(list_xt_next_i*list_eps_mu_t_next_i, dim=0)-exp_xt_next*torch.mean(list_eps_mu_t_next_i, dim=0)
                        else:
                            # Next step is not UQ: a point estimate of eps suffices.
                            eps_mu_t_next = custom_ld.accurate_forward(xt_next, get_model_input_time(ns, t).expand(xt_next.shape[0]), c=c)
                    else:
                        # Non-UQ step: deterministic update with the mean eps. The
                        # variance is propagated without an eps-variance term.
                        xt_next, model_s1 = conditioned_update(ns=ns, x=xt, s=s, t=t, custom_ld=custom_ld, model_s=eps_mu_t, pre_wuq=uq_array[timestep], c=c, r1=0.5)
                        exp_xt_next = conditioned_exp_iteration(exp_xt, ns, s, t, exp_s1= model_s1, pre_wuq=uq_array[timestep])
                        var_xt_next = conditioned_var_iteration(var_xt, ns, s, t, pre_wuq=uq_array[timestep])
                        # Same lookahead as in the UQ branch above.
                        if uq_array[timestep-1] == True:
                            list_xt_next_i, list_eps_mu_t_next_i=[], []
                            s_next = t_seq[timestep-1]
                            # NOTE(review): wraps to t_seq[-1] when timestep == 1.
                            t_next = t_seq[timestep-2]
                            lambda_s_next, lambda_t_next = ns.marginal_lambda(s_next), ns.marginal_lambda(t_next)
                            h_next = lambda_t_next - lambda_s_next
                            lambda_s1_next = lambda_s_next + 0.5 * h_next
                            s1_next = ns.inverse_lambda(lambda_s1_next)
                            sigma_s1_next = ns.marginal_std(s1_next)
                            log_alpha_s_next, log_alpha_s1_next = ns.marginal_log_mean_coeff(s_next), ns.marginal_log_mean_coeff(s1_next)
                            phi_11_next = torch.expm1(0.5*h_next)
                            for _ in range(mc_sample_size):
                                var_xt_next = torch.clamp(var_xt_next, min=0)
                                xt_next_i = sample_from_gaussion(exp_xt_next, var_xt_next)
                                list_xt_next_i.append(xt_next_i)
                                model_t_i, model_t_i_var = custom_ld(xt_next_i, get_model_input_time(ns, s_next).expand(xt_next_i.shape[0]), c=c)
                                xu_next_i = sample_from_gaussion(torch.exp(log_alpha_s1_next - log_alpha_s_next) * xt_next_i-(sigma_s1_next * phi_11_next) * model_t_i, \
                                torch.square(sigma_s1_next * phi_11_next) * model_t_i_var)
                                model_u_i, _ = custom_ld(xu_next_i, get_model_input_time(ns, s1_next).expand(xt_next_i.shape[0]), c=c)
                                list_eps_mu_t_next_i.append(model_u_i)
                            eps_mu_t_next, eps_var_t_next = custom_ld(xt_next, get_model_input_time(ns, s_next).expand(xt_next.shape[0]), c=c)
                            list_xt_next_i = torch.stack(list_xt_next_i, dim=0).to(device)
                            list_eps_mu_t_next_i = torch.stack(list_eps_mu_t_next_i, dim=0).to(device)
                            cov_xt_next_epst_next = torch.mean(list_xt_next_i*list_eps_mu_t_next_i, dim=0)-exp_xt_next*torch.mean(list_eps_mu_t_next_i, dim=0)
                        else:
                            eps_mu_t_next = custom_ld.accurate_forward(xt_next, get_model_input_time(ns, t).expand(xt_next.shape[0]), c=c)
                ####### Save variance and sample image ######
                # Total latent variance per sample for this round (one column of var_sum).
                var_sum[:, loop] = var_xt_next.sum(dim=(1,2,3))
                # Decode the final latent to image space.
                x_samples = model.decode_first_stage(xt_next) #
                # NOTE(review): the decoder expects latents, not variances, so
                # decoding var_xt_next is presumably not meaningful; kept disabled.
                # var_xt_next = model.decode_first_stage(var_xt_next)# add
                # Map decoder output to [0, 1] for saving.
                x = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                # os.makedirs(os.path.join(exp_dir, 'sam/'), exist_ok=True)
                # for i in range(x.shape[0]):
                # path = os.path.join(exp_dir, 'sam/', f"{img_id}.png")
                # tvu.save_image(x.cpu()[i].float(), path)
                # img_id += 1
                sample_x.append(x)
                # Keep the latent-space variance (same layout as var_xt_next; presumably
                # the latent shape of xT) for visualisation after the loop.
                var_x.append(var_xt_next) # add
# ---- Post-processing: rank samples by total latent variance and visualise ----
# Floor applied to the variance map before the log transform (tunable). The map
# is promoted to float32 first so this floor stays representable even if the
# sampler produced half-precision tensors under autocast.
VAR_LOG_FLOOR = 1e-8

sample_x = torch.concat(sample_x, dim=0)
var_x = torch.concat(var_x, dim=0)
# var_sum has shape (sample_batch_size, n_rounds); column `loop` holds the
# per-sample totals of round `loop`. sample_x / var_x were concatenated round
# by round, so flatten var_sum round-major (transpose first) to keep var[k]
# aligned with sample_x[k]. This equals concatenating var_sum[:, j] over j.
var = var_sum.t().reshape(-1)
sorted_var, sorted_indices = torch.sort(var, descending=True)
reordered_sample_x = torch.index_select(sample_x, dim=0, index=sorted_indices)
grid_sample_x = tvu.make_grid(reordered_sample_x, nrow=8, padding=2)
tvu.save_image(grid_sample_x.cpu().float(), os.path.join(exp_dir, "sorted_sample.png"))
print(f'Sampling {total_n_samples} images in {exp_dir}')
torch.save(var_sum.cpu(), os.path.join(exp_dir, 'var_sum.pt'))

# Per-pixel uncertainty maps, laid out exactly like sorted_sample.png:
#  * clamp negatives to 0. The sampler only clamps variance inside its
#    Monte-Carlo lookahead, so small negative values can survive to here.
#  * average over latent channels to get one map per sample.
#  * floor + log-compress. With a single linear min-max over the whole batch
#    (make_grid's default), a few high-variance pixels or images wash every
#    other map out to black.
#  * upsample from latent to image resolution ("nearest" keeps latent cells
#    crisp), then normalise each map by its own min/max (scale_each=True).
# The cross-image ranking is carried by the sort order and var_sum.pt; this
# grid shows where each image is uncertain.
var_map = torch.clamp(var_x.float(), min=0).mean(dim=1, keepdim=True)
var_map = torch.log(torch.clamp(var_map, min=VAR_LOG_FLOOR))
var_map = torch.nn.functional.interpolate(var_map, size=tuple(sample_x.shape[-2:]), mode="nearest")
reordered_var_x = torch.index_select(var_map, dim=0, index=sorted_indices)
grid_var_x = tvu.make_grid(reordered_var_x, nrow=8, padding=2, normalize=True, scale_each=True)
tvu.save_image(grid_var_x.cpu().float(), os.path.join(exp_dir, "sorted_var.png"))
Metadata
Metadata
Assignees
Labels
No labels