diff --git a/ChoreoAI_Zixuan_Wang/model/transformer.py b/ChoreoAI_Zixuan_Wang/model/transformer.py
index 5c3356c..a7459c6 100644
--- a/ChoreoAI_Zixuan_Wang/model/transformer.py
+++ b/ChoreoAI_Zixuan_Wang/model/transformer.py
@@ -16,6 +16,7 @@ def __init__(self, linear_num_features, n_head, latent_dim):
 
     def forward(self, x):
         x = x.reshape(x.shape[0], x.shape[1], -1)
+        # Fix: index by sequence length (dim 1), not batch (dim 0)
         x = self.pos_encoding(self.linear(x))
         attn_output, _ = self.multihead_attention(x, x, x)
         _, (hidden, _) = self.lstm(attn_output)
@@ -55,7 +56,7 @@ def sample_z(self, mean, log_var):
         batch, dim = mean.shape
         epsilon = torch.randn(batch, dim).to(self.device)
         return mean + torch.exp(0.5 * log_var) * epsilon
-    
+
     def forward(self, x, no_input=False):
         batch = x.size()[0]
         if no_input:
@@ -69,27 +70,152 @@ def forward(self, x, no_input=False):
         return x, mean, log_var
 
 
+class VAEForDuetEncoder(nn.Module):
+    """
+    Variant of VAEForSingleDancerEncoder whose input linear accepts
+    29 * 9 features instead of 29 * 3, to accommodate the richer
+    proximity representation used by VAEForDuet.
+    Everything else (positional encoding, attention, LSTM, mean/log_var
+    heads) is identical to the single-dancer encoder.
+    """
+    def __init__(self, linear_num_features, n_head, latent_dim,
+                 proximity_dim: int = 29 * 9):
+        super(VAEForDuetEncoder, self).__init__()
+        self.linear = nn.Linear(proximity_dim, linear_num_features)
+        self.pos_encoding = PositionalEncoding(linear_num_features)
+        self.multihead_attention = nn.MultiheadAttention(
+            embed_dim=linear_num_features, num_heads=n_head, batch_first=True
+        )
+        self.lstm = nn.LSTM(
+            input_size=linear_num_features,
+            hidden_size=linear_num_features,
+            num_layers=2,
+            batch_first=True,
+        )
+        self.mean = nn.Linear(linear_num_features, latent_dim)
+        self.log_var = nn.Linear(linear_num_features, latent_dim)
+
+    def forward(self, x):
+        # x: [B, T, 29*9] — already flattened by VAEForDuet.forward
+        x = self.pos_encoding(self.linear(x))
+        attn_output, _ = self.multihead_attention(x, x, x)
+        _, (hidden, _) = self.lstm(attn_output)
+        return self.mean(hidden[-1]), self.log_var(hidden[-1])
+
+
 class VAEForDuet(nn.Module):
-    def __init__(self, linear_num_features, n_head, latent_dim, n_units, seq_len, device='cuda'):
+    """
+    Duet VAE with enriched proximity encoding.
+
+    Instead of the single unsigned distance abs(d1 - d2) (29*3 features),
+    the proximity tensor now stacks three complementary signals per joint:
+
+        unsigned_dist = |d1 - d2|            direction-agnostic magnitude
+        signed_diff   = d1 - d2              who is ahead/behind on each axis
+        contact_flag  = (|d1 - d2| < 0.1)    binary near-contact indicator
+
+    Concatenated along the joint-feature axis this gives 29 * 9 features.
+    The richer signal lets the duet encoder distinguish mirroring from
+    following, and detect near-contact moments that strongly shape
+    partner choreography.
+
+    contact_threshold controls the per-axis absolute difference below which
+    two joints are considered "in contact". The default of 0.1 assumes
+    unit-scale joint positions; tune it to your dataset's coordinate range.
+ """ + def __init__(self, linear_num_features, n_head, latent_dim, n_units, + seq_len, device='cuda', contact_threshold: float = 0.1): super(VAEForDuet, self).__init__() - self.encoder = VAEForSingleDancerEncoder(linear_num_features, n_head, latent_dim) + self.encoder = VAEForDuetEncoder(linear_num_features, n_head, latent_dim) self.decoder = VAEForSingleDancerDecoder(latent_dim, n_units, seq_len) self.device = device - + self.contact_threshold = contact_threshold + def sample_z(self, mean, log_var): batch, dim = mean.shape epsilon = torch.randn(batch, dim).to(self.device) return mean + torch.exp(0.5 * log_var) * epsilon - + def forward(self, d1, d2): - # proximity - proximity = torch.abs(d1 - d2) + # d1, d2: [B, T, 29, 3] + unsigned_dist = torch.abs(d1 - d2) # [B, T, 29, 3] + signed_diff = d1 - d2 # [B, T, 29, 3] + contact_flag = (unsigned_dist < self.contact_threshold).float() # [B, T, 29, 3] + + # Stack along the feature axis then flatten joints+features together + proximity = torch.cat([unsigned_dist, signed_diff, contact_flag], dim=-1) # [B, T, 29, 9] + proximity = proximity.reshape(proximity.shape[0], proximity.shape[1], -1) # [B, T, 261] + mean, log_var = self.encoder(proximity) z = self.sample_z(mean, log_var) x = self.decoder(z) return x, mean, log_var +class DuetFusionAttention(nn.Module): + """ + Cross-attention fusion module. + + Replaces the naive elementwise addition `memory = out_i + out_duet` + with a learned attention mechanism: + + memory_i = LayerNorm(out_i + CrossAttn(Q=out_i, K=out_duet, V=out_duet)) + + This allows the model to selectively pull relevant parts of the shared + duet representation into each dancer's motion stream, rather than + blending every dimension equally regardless of relevance. + + Because the VAE decoder outputs 29*3=87 features (not divisible by + typical head counts), the module projects inputs into an internal + attention dimension (attn_dim, default 64) that IS head-divisible, + runs attention there, and projects back to d_model. The residual is + added in the original d_model space so no information is lost. + + Args: + d_model: Feature dimension of the VAE decoder outputs (29*3). + n_head: Number of attention heads. Must divide attn_dim evenly. + attn_dim: Internal dimension for the attention computation. + Defaults to d_model rounded down to the nearest multiple + of n_head, so you never have to set this manually. + dropout: Dropout applied inside the attention layer. 
+ """ + def __init__(self, d_model: int, n_head: int, attn_dim: int = None, dropout: float = 0.1): + super(DuetFusionAttention, self).__init__() + if attn_dim is None: + # Round down to nearest n_head multiple + attn_dim = (d_model // n_head) * n_head + self.proj_in = nn.Linear(d_model, attn_dim) + self.cross_attn = nn.MultiheadAttention( + embed_dim=attn_dim, + num_heads=n_head, + dropout=dropout, + batch_first=True, + ) + self.proj_out = nn.Linear(attn_dim, d_model) + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, dancer_out: torch.Tensor, duet_out: torch.Tensor) -> torch.Tensor: + """ + Args: + dancer_out: [batch, seq_len, d_model] — individual VAE decoder output + duet_out: [batch, seq_len, d_model] — duet VAE decoder output + + Returns: + memory: [batch, seq_len, d_model] — fused representation + """ + q = self.proj_in(dancer_out) # project to attn_dim + k = self.proj_in(duet_out) + v = self.proj_in(duet_out) + + attn_out, _ = self.cross_attn(query=q, key=k, value=v) + attn_out = self.proj_out(attn_out) # project back to d_model + + # Residual connection + layer norm for training stability + memory = self.norm(dancer_out + self.dropout(attn_out)) + return memory + + class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): super(PositionalEncoding, self).__init__() @@ -102,26 +228,31 @@ def __init__(self, d_model, max_len=5000): self.register_buffer('pe', pe) def forward(self, x): - x = x + self.pe[:x.size(0), :, :x.size(2)] + # Fix: use x.size(1) for sequence length when batch_first=True + x = x + self.pe[:x.size(1), :, :x.size(2)].transpose(0, 1) return x class TransformerDecoder(nn.Module): - def __init__(self, d_model, nhead=8, num_layers=2, dim_feedforward=256): + def __init__(self, d_model, nhead=8, num_layers=2, dim_feedforward=256, memory_dim=None): super(TransformerDecoder, self).__init__() - self.linear = nn.Linear(29 * 3, d_model) + # tgt always comes from raw joint positions (29*3) + self.tgt_linear = nn.Linear(29 * 3, d_model) + # memory comes from DuetFusionAttention whose output is n_units (= d_model here), + # but we keep a separate projection so the two paths are decoupled and + # memory_dim can differ from 29*3 without touching tgt_linear. 
+
+
 class PositionalEncoding(nn.Module):
     def __init__(self, d_model, max_len=5000):
         super(PositionalEncoding, self).__init__()
@@ -102,26 +228,31 @@ def __init__(self, d_model, max_len=5000):
         self.register_buffer('pe', pe)
 
     def forward(self, x):
-        x = x + self.pe[:x.size(0), :, :x.size(2)]
+        # Fix: use x.size(1) for sequence length when batch_first=True
+        x = x + self.pe[:x.size(1), :, :x.size(2)].transpose(0, 1)
         return x
 
 
 class TransformerDecoder(nn.Module):
-    def __init__(self, d_model, nhead=8, num_layers=2, dim_feedforward=256):
+    def __init__(self, d_model, nhead=8, num_layers=2, dim_feedforward=256, memory_dim=None):
         super(TransformerDecoder, self).__init__()
-        self.linear = nn.Linear(29 * 3, d_model)
+        # tgt always comes from raw joint positions (29*3)
+        self.tgt_linear = nn.Linear(29 * 3, d_model)
+        # memory comes from DuetFusionAttention, whose output here is 29 * 3 features;
+        # we keep a separate projection so the two paths stay decoupled and
+        # memory_dim can differ from 29*3 without touching tgt_linear.
+        self.memory_linear = nn.Linear(memory_dim if memory_dim is not None else d_model, d_model)
         self.pos_encoder = PositionalEncoding(d_model)
         self.decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward)
         self.transformer_decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_layers)
         self.fc_out = nn.Linear(d_model, 29 * 3)
 
     def forward(self, tgt, memory):
-        # tgt: [29, 3], memory: [29 * 3]
         batch_size, seq_len, num_joints, joint_dim = tgt.shape
         tgt = tgt.view(batch_size, seq_len, num_joints * joint_dim)
-        tgt = self.linear(tgt)
-        memory = self.linear(memory)
+        tgt = self.tgt_linear(tgt)            # [B, T, d_model]
+        memory = self.memory_linear(memory)   # [B, T, d_model]
 
         tgt = self.pos_encoder(tgt)
         output = self.transformer_decoder(tgt, memory)
@@ -136,10 +267,15 @@ def __init__(self, linear_num_features, n_head, latent_dim, n_units, seq_len, no
         self.vae_1 = VAEForSingleDancer(linear_num_features, n_head, latent_dim, n_units, seq_len, default_log_var)
         self.vae_2 = VAEForSingleDancer(linear_num_features, n_head, latent_dim, n_units, seq_len, default_log_var)
         self.vae_duet = VAEForDuet(linear_num_features, n_head, latent_dim, n_units, seq_len)
-        self.transformer_decoder_1 = TransformerDecoder(linear_num_features)
-        self.transformer_decoder_2 = TransformerDecoder(linear_num_features)
+
+        # VAE decoders output [B, T, 29*3] — fusion d_model must match
+        self.fusion_1 = DuetFusionAttention(d_model=29 * 3, n_head=n_head)
+        self.fusion_2 = DuetFusionAttention(d_model=29 * 3, n_head=n_head)
+
+        self.transformer_decoder_1 = TransformerDecoder(linear_num_features, memory_dim=29 * 3)
+        self.transformer_decoder_2 = TransformerDecoder(linear_num_features, memory_dim=29 * 3)
         self.no_input_prob = no_input_prob
-    
+
     def forward(self, d1, d2, is_inference=False):
         rdm_val = torch.randn(1)
         is_simplified_model = False
@@ -154,7 +290,6 @@ def forward(self, d1, d2, is_inference=False):
 
         if not is_inference and rdm_val < self.no_input_prob:
             is_simplified_model = True
-            # only focus on one VAE model
             out_1, mean_1, log_var_1 = self.vae_1(d1_normalized)
             out_2, mean_2, log_var_2 = self.vae_2(d2_normalized)
             batch_size, seq_len, _ = out_1.shape
@@ -170,11 +305,14 @@ def forward(self, d1, d2, is_inference=False):
         out_2, mean_2, log_var_2 = self.vae_2(d2_normalized)
         out_duet, mean_duet, log_var_duet = self.vae_duet(d1_normalized, d2_normalized)
 
-        # [batch_size, seq_len, 29 * 3]
-        memory_1 = out_1 + out_duet
-        memory_2 = out_2 + out_duet
+        # Cross-attention fusion:
+        #   dancer 1 queries the duet stream to build its context memory,
+        #   dancer 2 queries the duet stream independently.
+        # Shape: [batch, seq_len, 29 * 3] throughout
+        memory_1 = self.fusion_1(dancer_out=out_1, duet_out=out_duet)
+        memory_2 = self.fusion_2(dancer_out=out_2, duet_out=out_duet)
 
-        # transformer decoder
+        # Transformer decoder uses fused memory to predict the other dancer
        pred_2 = self.transformer_decoder_1(d2_normalized, memory_1)
        pred_1 = self.transformer_decoder_2(d1_normalized, memory_2)
 
@@ -182,10 +320,20 @@ if __name__ == '__main__':
-    model = DancerTransformer(64, 8, 32, 32, 64).to('cuda')
-    print(model)
-    input_1 = torch.rand(8, 64, 29, 3).to('cuda')
-    input_2 = torch.rand(8, 64, 29, 3).to('cuda')
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    model = DancerTransformer(
+        linear_num_features=64,
+        n_head=8,
+        latent_dim=32,
+        n_units=64,
+        seq_len=64,
+        no_input_prob=0.2,
+    ).to(device)
+
+    input_1 = torch.rand(8, 64, 29, 3).to(device)
+    input_2 = torch.rand(8, 64, 29, 3).to(device)
 
-    out_1, out_2, _, _, _, _, _, _ = model(input_1, input_2)
-    print(out_1.shape)
+    out = model(input_1, input_2)
+    pred_1, pred_2 = out[0], out[1]
+    print("pred_1 shape:", pred_1.shape)   # expect [8, 64, 29, 3]
+    print("pred_2 shape:", pred_2.shape)