Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def sample(self, trainer, model, data_batch, output_batch, loss, iteration):
if self.fix_batch is not None:
data_batch = misc.to(self.fix_batch, **model.tensor_kwargs)
tag = "ema" if self.is_ema else "reg"
raw_data, x0, condition = model.get_data_and_condition(data_batch)
raw_data, x0, condition = model.get_data_and_condition(data_batch)[:3]
if self.use_negative_prompt:
batch_size = x0.shape[0]
data_batch["neg_t5_text_embeddings"] = misc.to(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def generate_from_batch(
if self.model.config.text_encoder_config is not None and self.model.config.text_encoder_config.compute_online:
self.model.inplace_compute_text_embeddings_online(data_batch)

raw_data, x0, condition = self.model.get_data_and_condition(data_batch)
raw_data, x0, condition = self.model.get_data_and_condition(data_batch)[:3]

self.model.eval()
sample = self.model.generate_samples_from_batch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from cosmos_transfer2._src.predict2_multiview.models.multiview_vid2vid_model_rectified_flow import (
compute_empty_and_negative_text_embeddings,
inplace_compute_text_embeddings_online_multiview,
preprocess_databatch,
training_step_multiview,
)
from cosmos_transfer2._src.transfer2.models.vid2vid_model_control_vace_rectified_flow import (
Expand All @@ -61,8 +62,101 @@ class MultiviewControlVideo2WorldRectifiedFlowConfig(ControlVideo2WorldRectified
conditional_frames_probs: Optional[Dict[int, float]] = None # Probability distribution for conditional frames


def training_step_multiview_latent_mask(
    model, data_batch: dict[str, torch.Tensor], iteration: int
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
    """
    Performs a single rectified-flow training step with a latent-weighted loss.

    The step:
    1. Preprocesses the batch and, when configured, computes text embeddings online.
    2. Fetches the clean latent, the condition and the latent weights from the model.
    3. Samples noise and a training time, then builds the interpolated input and the
       target velocity through ``model.rectified_flow``.
    4. Predicts the velocity with ``model.denoise`` and averages the squared error over
       every non-batch dimension, each element scaled by
       ``hdmap_latent_weights_lambda * latent_weights + hdmap_latent_weights_beta``.
    5. Averages the per-instance losses, scaled by the train-time weights.

    Args:
        model: Model object supplying the config, the rectified-flow helper, the denoiser
            and the ``hdmap_latent_weights_*`` coefficients. Its ``get_data_and_condition``
            must return four values here; the fourth is used as the latent weights.
        data_batch (dict): raw data batch drawn from the training data loader.
        iteration (int): Current iteration number. Not read by this function.

    Returns:
        tuple: A tuple containing two elements:
            - dict: additional data used for debugging / logging / callbacks, with keys
              ``x0``, ``xt``, ``sigma``, ``condition``, ``model_pred`` and ``edm_loss``.
            - Tensor: The computed loss for the training step as a PyTorch Tensor.
    """
    model._update_train_stats(data_batch)

    # only happens in training
    data_batch = preprocess_databatch(data_batch, model.config.train_sample_views_range)

    if model.config.text_encoder_config is not None and model.config.text_encoder_config.compute_online:
        model.inplace_compute_text_embeddings_online(data_batch)

    # Get the input data to noise and denoise (image or video), the corresponding
    # conditioner, and the per-element latent weights used by the loss below.
    _, x0_B_C_T_H_W, condition, latent_weights = model.get_data_and_condition(data_batch)

    # Sample perturbation noise levels and N(0, 1) noises
    epsilon_B_C_T_H_W = torch.randn(x0_B_C_T_H_W.size(), **model.tensor_kwargs_fp32)
    batch_size = x0_B_C_T_H_W.size()[0]
    t_B = model.rectified_flow.sample_train_time(batch_size).to(**model.tensor_kwargs_fp32)
    t_B = rearrange(t_B, "b -> b 1")  # add a dimension for T, all frames share the same sigma
    # NOTE(review): latent_weights is not passed through this call, unlike x0 / epsilon.
    # Confirm it still matches the prediction's shape if this call splits its inputs.
    x0_B_C_T_H_W, condition, epsilon_B_C_T_H_W, t_B = model.broadcast_split_for_model_parallelsim(
        x0_B_C_T_H_W, condition, epsilon_B_C_T_H_W, t_B
    )
    timesteps = model.rectified_flow.get_discrete_timestamp(t_B, model.tensor_kwargs_fp32)
    sigmas = model.rectified_flow.get_sigmas(
        timesteps,
        model.tensor_kwargs_fp32,
    )
    timesteps = rearrange(timesteps, "b -> b 1")
    sigmas = rearrange(sigmas, "b -> b 1")
    xt_B_C_T_H_W, vt_B_C_T_H_W = model.rectified_flow.get_interpolation(epsilon_B_C_T_H_W, x0_B_C_T_H_W, sigmas)

    vt_pred_B_C_T_H_W = model.denoise(
        noise=epsilon_B_C_T_H_W,
        xt_B_C_T_H_W=xt_B_C_T_H_W.to(**model.tensor_kwargs),
        timesteps_B_T=timesteps,
        condition=condition,
    )

    # NOTE(review): timesteps was reshaped to (b, 1) above while per_instance_loss is
    # reduced to (b,). Confirm train_time_weight returns a shape that multiplies
    # element-wise instead of broadcasting to (b, b).
    time_weights_B = model.rectified_flow.train_time_weight(timesteps, model.tensor_kwargs_fp32)
    per_instance_loss = torch.mean(
        (model.hdmap_latent_weights_lambda * latent_weights + model.hdmap_latent_weights_beta)
        * (vt_pred_B_C_T_H_W - vt_B_C_T_H_W) ** 2,
        dim=list(range(1, vt_pred_B_C_T_H_W.dim())),
    )

    loss = torch.mean(time_weights_B * per_instance_loss)

    output_batch = {
        "x0": x0_B_C_T_H_W,
        "xt": xt_B_C_T_H_W,
        "sigma": sigmas,
        "condition": condition,
        "model_pred": vt_pred_B_C_T_H_W,
        "edm_loss": loss,
    }

    return output_batch, loss


class MultiviewControlVideo2WorldModelRectifiedFlow(ControlVideo2WorldModelRectifiedFlow):
def __init__(self, config: MultiviewControlVideo2WorldRectifiedFlowConfig, *args, **kwargs):
def __init__(
self,
config: MultiviewControlVideo2WorldRectifiedFlowConfig,
*args,
enable_hdmap_latent_weights=False,
hdmap_latent_weights_lambda=1,
hdmap_latent_weights_beta=1,
**kwargs,
):
self.is_new_training = True
self.copy_weight_strategy = config.copy_weight_strategy
self.hint_keys = []
Expand All @@ -85,6 +179,10 @@ def __init__(self, config: MultiviewControlVideo2WorldRectifiedFlowConfig, *args
if self.config.text_encoder_config is not None and self.config.text_encoder_config.compute_online:
compute_empty_and_negative_text_embeddings(self)

self.enable_hdmap_latent_weights = enable_hdmap_latent_weights
self.hdmap_latent_weights_lambda = hdmap_latent_weights_lambda
self.hdmap_latent_weights_beta = hdmap_latent_weights_beta

@torch.no_grad()
def encode(self, state: torch.Tensor) -> torch.Tensor:
n_views = state.shape[2] // self.tokenizer.get_pixel_num_frames(self.state_t)
Expand Down Expand Up @@ -144,7 +242,10 @@ def decode_cp(self, latent: torch.Tensor) -> torch.Tensor:
def training_step(
    self, data_batch: dict[str, torch.Tensor], iteration: int
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
    """
    Run one training step, choosing the loss variant from the model's flag.

    Args:
        data_batch (dict): raw data batch from the training data loader.
        iteration (int): Current iteration number.

    Returns:
        tuple: whatever the selected step function returns for this batch.
    """
    # The latent-mask variant is used only when HD-map latent weighting is enabled;
    # otherwise the regular multiview step runs.
    if self.enable_hdmap_latent_weights:
        return training_step_multiview_latent_mask(self, data_batch, iteration)
    return training_step_multiview(self, data_batch, iteration)

def inplace_compute_text_embeddings_online(self, data_batch: dict[str, torch.Tensor]) -> None:
inplace_compute_text_embeddings_online_multiview(self, data_batch)
Expand Down Expand Up @@ -263,9 +364,11 @@ def get_data_and_condition(
latent_control_input = []
for hint_key in self.hint_keys:
control_input = getattr(condition, hint_key, None)

if control_input is not None:
raw_and_control_inputs.append(control_input)
if hint_key == "control_input_hdmap_bbox" and self.enable_hdmap_latent_weights:
latent_mask = (control_input.sum(dim=2, keepdim=True) > 0.12).float().mul(255).repeat(1, 1, 3, 1, 1)
raw_and_control_inputs.append(latent_mask)
num_modalities += 1
assert num_modalities > 0, "No control input found"
encoded_tensors = self._encode_raw_and_control_inputs(raw_and_control_inputs)
Expand Down Expand Up @@ -299,6 +402,9 @@ def get_data_and_condition(
latent_control_input=latent_control_input,
control_weight=data_batch.get(CONTROL_WEIGHT_KEY, 1.0),
)
if self.enable_hdmap_latent_weights:
latent_weights = encoded_tensors[-1]
return raw_state, latent_state, condition, latent_weights
return raw_state, latent_state, condition

def get_velocity_fn_from_batch(
Expand Down Expand Up @@ -340,7 +446,10 @@ def get_velocity_fn_from_batch(
is_image_batch = self.is_image_batch(data_batch_with_latent_view_indices)
condition = condition.edit_data_type(DataType.IMAGE if is_image_batch else DataType.VIDEO)
uncondition = uncondition.edit_data_type(DataType.IMAGE if is_image_batch else DataType.VIDEO)
_, x0, data_batch_condition = self.get_data_and_condition(data_batch_with_latent_view_indices)
if self.enable_hdmap_latent_weights:
_, x0, data_batch_condition, _ = self.get_data_and_condition(data_batch_with_latent_view_indices)
else:
_, x0, data_batch_condition = self.get_data_and_condition(data_batch_with_latent_view_indices)
# override condition with inference mode; num_conditional_frames used Here!
condition = condition.set_video_condition(
state_t=self.config.state_t,
Expand Down Expand Up @@ -442,7 +551,10 @@ def get_x0_fn_from_batch(
is_image_batch = self.is_image_batch(data_batch_with_latent_view_indices)
condition = condition.edit_data_type(DataType.IMAGE if is_image_batch else DataType.VIDEO)
uncondition = uncondition.edit_data_type(DataType.IMAGE if is_image_batch else DataType.VIDEO)
_, x0, data_batch_condition = self.get_data_and_condition(data_batch_with_latent_view_indices)
if self.enable_hdmap_latent_weights:
_, x0, data_batch_condition, _ = self.get_data_and_condition(data_batch_with_latent_view_indices)
else:
_, x0, data_batch_condition = self.get_data_and_condition(data_batch_with_latent_view_indices)
# override condition with inference mode; num_conditional_frames used Here!
condition = condition.set_video_condition(
state_t=self.config.state_t,
Expand Down